1use std::io;
2use std::path::{Path, PathBuf};
3use std::process::{Command, ExitStatus, Stdio};
4
5use crate::fs_util;
6
/// A named shell command stored in the snippets file.
///
/// `SnippetStore::save` writes each snippet as a section:
///
/// ```text
/// [name]
/// command=...
/// description=...   (line omitted when the description is empty)
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Snippet {
    /// Section name (`[name]`); `SnippetStore::parse` and `SnippetStore::set`
    /// keep it unique within a store.
    pub name: String,
    /// Command sent to the remote host; may contain `{{param}}` /
    /// `{{param:default}}` placeholders.
    pub command: String,
    /// Free-form description; empty when the file has no `description=` line.
    pub description: String,
}
14
/// Outcome of one foreground [`run_snippet`] invocation.
#[derive(Debug)]
pub struct SnippetResult {
    /// Exit status of the local `ssh` process.
    pub status: ExitStatus,
    /// Captured stdout, decoded lossily as UTF-8; empty when output was not captured.
    pub stdout: String,
    /// Captured stderr, decoded lossily as UTF-8; empty when output was not captured.
    pub stderr: String,
}
21
/// In-memory collection of snippets, read from and written to the snippets file.
#[derive(Debug, Clone, Default)]
pub struct SnippetStore {
    /// Snippets in the order they were parsed or added.
    pub snippets: Vec<Snippet>,
    /// When set, `save` writes to this path instead of the default
    /// `~/.purple/snippets`; `load` always reads the default location.
    pub path_override: Option<PathBuf>,
}
29
30fn config_path() -> Option<PathBuf> {
31 dirs::home_dir().map(|h| h.join(".purple/snippets"))
32}
33
34impl SnippetStore {
35 pub fn load() -> Self {
38 let path = match config_path() {
39 Some(p) => p,
40 None => return Self::default(),
41 };
42 let content = match std::fs::read_to_string(&path) {
43 Ok(c) => c,
44 Err(e) if e.kind() == io::ErrorKind::NotFound => return Self::default(),
45 Err(e) => {
46 log::warn!("[config] Could not read {}: {}", path.display(), e);
47 return Self::default();
48 }
49 };
50 Self::parse(&content)
51 }
52
53 pub fn parse(content: &str) -> Self {
55 let mut snippets = Vec::new();
56 let mut current: Option<Snippet> = None;
57
58 for line in content.lines() {
59 let trimmed = line.trim();
60 if trimmed.is_empty() || trimmed.starts_with('#') {
61 continue;
62 }
63 if trimmed.starts_with('[') && trimmed.ends_with(']') {
64 if let Some(snippet) = current.take() {
65 if !snippet.command.is_empty()
66 && !snippets.iter().any(|s: &Snippet| s.name == snippet.name)
67 {
68 snippets.push(snippet);
69 }
70 }
71 let name = trimmed[1..trimmed.len() - 1].trim().to_string();
72 if snippets.iter().any(|s| s.name == name) {
73 current = None;
74 continue;
75 }
76 current = Some(Snippet {
77 name,
78 command: String::new(),
79 description: String::new(),
80 });
81 } else if let Some(ref mut snippet) = current {
82 if let Some((key, value)) = trimmed.split_once('=') {
83 let key = key.trim();
84 let value = value.trim_start().to_string();
87 match key {
88 "command" => snippet.command = value,
89 "description" => snippet.description = value,
90 _ => {}
91 }
92 }
93 }
94 }
95 if let Some(snippet) = current {
96 if !snippet.command.is_empty() && !snippets.iter().any(|s| s.name == snippet.name) {
97 snippets.push(snippet);
98 }
99 }
100 Self {
101 snippets,
102 path_override: None,
103 }
104 }
105
106 pub fn save(&self) -> io::Result<()> {
108 if crate::demo_flag::is_demo() {
109 return Ok(());
110 }
111 let path = match &self.path_override {
112 Some(p) => p.clone(),
113 None => match config_path() {
114 Some(p) => p,
115 None => {
116 return Err(io::Error::new(
117 io::ErrorKind::NotFound,
118 "Could not determine home directory",
119 ));
120 }
121 },
122 };
123
124 let mut content = String::new();
125 for (i, snippet) in self.snippets.iter().enumerate() {
126 if i > 0 {
127 content.push('\n');
128 }
129 content.push_str(&format!("[{}]\n", snippet.name));
130 content.push_str(&format!("command={}\n", snippet.command));
131 if !snippet.description.is_empty() {
132 content.push_str(&format!("description={}\n", snippet.description));
133 }
134 }
135
136 fs_util::atomic_write(&path, content.as_bytes())
137 }
138
139 pub fn get(&self, name: &str) -> Option<&Snippet> {
141 self.snippets.iter().find(|s| s.name == name)
142 }
143
144 pub fn set(&mut self, snippet: Snippet) {
146 if let Some(existing) = self.snippets.iter_mut().find(|s| s.name == snippet.name) {
147 *existing = snippet;
148 } else {
149 self.snippets.push(snippet);
150 }
151 }
152
153 pub fn remove(&mut self, name: &str) {
155 self.snippets.retain(|s| s.name != name);
156 }
157}
158
/// Check that `name` can be used as a snippet section header.
///
/// The name must not be blank and must not have leading or trailing
/// whitespace. It must not contain `#`, `[` or `]`, which would break the file
/// format, and it must not contain control characters. The first failing rule
/// determines the error message.
pub fn validate_name(name: &str) -> Result<(), String> {
    let trimmed = name.trim();
    let problem = if trimmed.is_empty() {
        Some("Snippet name cannot be empty.")
    } else if trimmed.len() != name.len() {
        Some("Snippet name cannot have leading or trailing whitespace.")
    } else if name.chars().any(|c| matches!(c, '#' | '[' | ']')) {
        Some("Snippet name cannot contain #, [ or ].")
    } else if name.chars().any(char::is_control) {
        Some("Snippet name cannot contain control characters.")
    } else {
        None
    };
    match problem {
        Some(message) => Err(message.to_string()),
        None => Ok(()),
    }
}
176
/// Check that `command` is non-blank and contains no control characters
/// other than tab. A blank command is reported before control characters.
pub fn validate_command(command: &str) -> Result<(), String> {
    if command.trim().is_empty() {
        Err("Command cannot be empty.".to_owned())
    } else if command.chars().any(|c| c != '\t' && c.is_control()) {
        Err("Command cannot contain control characters.".to_owned())
    } else {
        Ok(())
    }
}
187
/// A placeholder found in a snippet command by `parse_params`.
#[derive(Debug, Clone, PartialEq)]
pub struct SnippetParam {
    /// Placeholder name, from `{{name}}` or `{{name:default}}`.
    pub name: String,
    /// Text after the first `:` inside the placeholder, if any.
    pub default: Option<String>,
}
198
/// Quote `s` as a single POSIX-shell word.
///
/// The text is wrapped in single quotes, and every embedded `'` becomes
/// `'\''`: close the quote, add an escaped quote, then reopen.
pub fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
204
205pub fn parse_params(command: &str) -> Vec<SnippetParam> {
208 let mut params = Vec::new();
209 let mut seen = std::collections::HashSet::new();
210 let bytes = command.as_bytes();
211 let len = bytes.len();
212 let mut i = 0;
213 while i + 3 < len {
214 if bytes[i] == b'{' && bytes.get(i + 1) == Some(&b'{') {
215 if let Some(end) = command[i + 2..].find("}}") {
216 let inner = &command[i + 2..i + 2 + end];
217 let (name, default) = if let Some((n, d)) = inner.split_once(':') {
218 (n.to_string(), Some(d.to_string()))
219 } else {
220 (inner.to_string(), None)
221 };
222 if validate_param_name(&name).is_ok() && !seen.contains(&name) && params.len() < 20
223 {
224 seen.insert(name.clone());
225 params.push(SnippetParam { name, default });
226 }
227 i = i + 2 + end + 2;
228 continue;
229 }
230 }
231 i += 1;
232 }
233 params
234}
235
/// Check that `name` is a usable placeholder name.
///
/// The name must be non-empty and consist only of Unicode alphanumerics,
/// `_` or `-`.
pub fn validate_param_name(name: &str) -> Result<(), String> {
    let allowed = |c: char| c.is_alphanumeric() || matches!(c, '_' | '-');
    if name.is_empty() {
        return Err("Parameter name cannot be empty.".to_string());
    }
    match name.chars().find(|&c| !allowed(c)) {
        Some(_) => Err(format!(
            "Parameter name '{}' contains invalid characters.",
            name
        )),
        None => Ok(()),
    }
}
253
254pub fn substitute_params(
257 command: &str,
258 values: &std::collections::HashMap<String, String>,
259) -> String {
260 let mut result = String::with_capacity(command.len());
261 let bytes = command.as_bytes();
262 let len = bytes.len();
263 let mut i = 0;
264 while i < len {
265 if i + 3 < len && bytes[i] == b'{' && bytes[i + 1] == b'{' {
266 if let Some(end) = command[i + 2..].find("}}") {
267 let inner = &command[i + 2..i + 2 + end];
268 let (name, default) = if let Some((n, d)) = inner.split_once(':') {
269 (n, Some(d))
270 } else {
271 (inner, None)
272 };
273 let value = values
274 .get(name)
275 .filter(|v| !v.is_empty())
276 .map(|v| v.as_str())
277 .or(default)
278 .unwrap_or("");
279 result.push_str(&shell_escape(value));
280 i = i + 2 + end + 2;
281 continue;
282 }
283 }
284 let ch = command[i..].chars().next().unwrap();
286 result.push(ch);
287 i += ch.len_utf8();
288 }
289 result
290}
291
292pub fn sanitize_output(input: &str) -> String {
299 let mut out = String::with_capacity(input.len());
300 let mut chars = input.chars().peekable();
301 while let Some(c) = chars.next() {
302 match c {
303 '\x1b' => {
304 match chars.peek() {
305 Some('[') => {
306 chars.next();
307 while let Some(&ch) = chars.peek() {
309 chars.next();
310 if ('\x40'..='\x7e').contains(&ch) {
311 break;
312 }
313 }
314 }
315 Some(']') | Some('P') | Some('X') | Some('^') | Some('_') => {
316 chars.next();
317 consume_until_st(&mut chars);
319 }
320 _ => {
321 chars.next();
323 }
324 }
325 }
326 c if ('\u{0080}'..='\u{009F}').contains(&c) => {
327 }
329 c if c.is_control() && c != '\n' && c != '\t' => {
330 }
332 _ => out.push(c),
333 }
334 }
335 out
336}
337
/// Skip the body of a control string up to and including its terminator.
///
/// The terminator is BEL (`\x07`) or `ESC \`. An `ESC` that is not followed by
/// `\` also ends the string; only the `ESC` is consumed in that case. Stops
/// quietly at end of input.
fn consume_until_st(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    while let Some(ch) = chars.next() {
        match ch {
            '\x07' => return,
            '\x1b' => {
                if chars.peek() == Some(&'\\') {
                    chars.next();
                }
                return;
            }
            _ => {}
        }
    }
}
355
/// Maximum number of lines `read_pipe_capped` keeps from a single pipe.
/// Later lines are read and discarded, and a truncation notice is appended.
const MAX_OUTPUT_LINES: usize = 10_000;
363
/// Messages sent over the channel given to [`spawn_snippet_execution`].
///
/// Every variant carries the `run_id` passed to `spawn_snippet_execution`,
/// identifying the run the event belongs to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnippetEvent {
    /// A single host finished, or ssh could not be launched for it.
    HostDone {
        run_id: u64,
        /// Alias of the host the command ran on.
        alias: String,
        /// Remote stdout after [`sanitize_output`]; empty if ssh failed to launch.
        stdout: String,
        /// Remote stderr after [`sanitize_output`], or a "Failed to launch
        /// ssh" message when spawning failed.
        stderr: String,
        /// Exit code of the ssh process. On Unix a signal-terminated process
        /// is reported as `128 + signal`. `None` when ssh could not be launched
        /// or waited on.
        exit_code: Option<i32>,
    },
    /// `completed` of `total` hosts have been handled so far.
    Progress {
        run_id: u64,
        completed: usize,
        total: usize,
    },
    /// Sent once when the run ends, including after cancellation.
    AllDone { run_id: u64 },
}
383
/// Owns a spawned ssh child and tears it down when dropped.
///
/// If the guard is dropped while the child is still running, its `Drop` does
/// the following:
/// - signals the process group `pgid` with SIGTERM (Unix), then SIGKILL after
///   a 500 ms grace period;
/// - kills the child;
/// - reaps the child.
///
/// The mutex lets the child be taken out of the guard, as `execute_host` does
/// once it has waited for it. After that, dropping the guard does nothing.
pub struct ChildGuard {
    // The child process; `None` once it has been reaped and taken out.
    inner: std::sync::Mutex<Option<std::process::Child>>,
    // Process-group id to signal: the child's pid, or -1 when the pid does not
    // fit in an i32 (see `ChildGuard::new`).
    pgid: i32,
}
390
391impl ChildGuard {
392 fn new(child: std::process::Child) -> Self {
393 let pgid = i32::try_from(child.id()).unwrap_or(-1);
397 Self {
398 inner: std::sync::Mutex::new(Some(child)),
399 pgid,
400 }
401 }
402}
403
404impl Drop for ChildGuard {
405 fn drop(&mut self) {
406 let mut lock = self.inner.lock().unwrap_or_else(|e| e.into_inner());
407 if let Some(ref mut child) = *lock {
408 if let Ok(Some(_)) = child.try_wait() {
410 return;
411 }
412 #[cfg(unix)]
418 unsafe {
419 libc::kill(-self.pgid, libc::SIGTERM);
420 }
421 let deadline = std::time::Instant::now() + std::time::Duration::from_millis(500);
423 loop {
424 if let Ok(Some(_)) = child.try_wait() {
425 return;
426 }
427 if std::time::Instant::now() >= deadline {
428 break;
429 }
430 std::thread::sleep(std::time::Duration::from_millis(50));
431 }
432 #[cfg(unix)]
434 unsafe {
435 libc::kill(-self.pgid, libc::SIGKILL);
436 }
437 let _ = child.kill();
439 let _ = child.wait();
440 }
441 }
442}
443
444fn read_pipe_capped<R: io::Read>(reader: R) -> String {
447 use io::BufRead;
448 let mut reader = io::BufReader::new(reader);
449 let mut output = String::new();
450 let mut line_count = 0;
451 let mut capped = false;
452 let mut buf = Vec::new();
453 loop {
454 buf.clear();
455 match reader.read_until(b'\n', &mut buf) {
456 Ok(0) => break, Ok(_) => {
458 if !capped {
459 if line_count < MAX_OUTPUT_LINES {
460 if line_count > 0 {
461 output.push('\n');
462 }
463 if buf.last() == Some(&b'\n') {
465 buf.pop();
466 if buf.last() == Some(&b'\r') {
467 buf.pop();
468 }
469 }
470 output.push_str(&String::from_utf8_lossy(&buf));
472 line_count += 1;
473 } else {
474 output.push_str("\n[Output truncated at 10,000 lines]");
475 capped = true;
476 }
477 }
478 }
480 Err(_) => break,
481 }
482 }
483 output
484}
485
486fn base_ssh_command(
490 alias: &str,
491 config_path: &Path,
492 command: &str,
493 askpass: Option<&str>,
494 bw_session: Option<&str>,
495 has_active_tunnel: bool,
496) -> Command {
497 let mut cmd = Command::new("ssh");
498 cmd.arg("-F")
499 .arg(config_path)
500 .arg("-o")
501 .arg("ConnectTimeout=10")
502 .arg("-o")
503 .arg("ControlMaster=no")
504 .arg("-o")
505 .arg("ControlPath=none");
506
507 if has_active_tunnel {
508 cmd.arg("-o").arg("ClearAllForwardings=yes");
509 }
510
511 cmd.arg("--").arg(alias).arg(command);
512
513 if askpass.is_some() {
514 crate::askpass_env::configure_ssh_command(&mut cmd, alias, config_path);
515 }
516
517 if let Some(token) = bw_session {
518 cmd.env("BW_SESSION", token);
519 }
520
521 cmd
522}
523
524fn build_snippet_command(
526 alias: &str,
527 config_path: &Path,
528 command: &str,
529 askpass: Option<&str>,
530 bw_session: Option<&str>,
531 has_active_tunnel: bool,
532) -> Command {
533 let mut cmd = base_ssh_command(
534 alias,
535 config_path,
536 command,
537 askpass,
538 bw_session,
539 has_active_tunnel,
540 );
541 cmd.stdin(Stdio::null())
542 .stdout(Stdio::piped())
543 .stderr(Stdio::piped());
544
545 #[cfg(unix)]
548 unsafe {
549 use std::os::unix::process::CommandExt;
550 cmd.pre_exec(|| {
551 libc::setpgid(0, 0);
552 Ok(())
553 });
554 }
555
556 cmd
557}
558
559#[allow(clippy::too_many_arguments)]
561fn execute_host(
562 run_id: u64,
563 alias: &str,
564 config_path: &Path,
565 command: &str,
566 askpass: Option<&str>,
567 bw_session: Option<&str>,
568 has_active_tunnel: bool,
569 tx: &std::sync::mpsc::Sender<SnippetEvent>,
570) -> Option<std::sync::Arc<ChildGuard>> {
571 let mut cmd = build_snippet_command(
572 alias,
573 config_path,
574 command,
575 askpass,
576 bw_session,
577 has_active_tunnel,
578 );
579
580 match cmd.spawn() {
581 Ok(child) => {
582 let guard = std::sync::Arc::new(ChildGuard::new(child));
583
584 let stdout_pipe = {
586 let mut lock = guard.inner.lock().unwrap_or_else(|e| e.into_inner());
587 lock.as_mut().and_then(|c| c.stdout.take())
588 };
589 let stderr_pipe = {
590 let mut lock = guard.inner.lock().unwrap_or_else(|e| e.into_inner());
591 lock.as_mut().and_then(|c| c.stderr.take())
592 };
593
594 let stdout_handle = std::thread::spawn(move || match stdout_pipe {
596 Some(pipe) => read_pipe_capped(pipe),
597 None => String::new(),
598 });
599 let stderr_handle = std::thread::spawn(move || match stderr_pipe {
600 Some(pipe) => read_pipe_capped(pipe),
601 None => String::new(),
602 });
603
604 let stdout_text = stdout_handle.join().unwrap_or_default();
606 let stderr_text = stderr_handle.join().unwrap_or_default();
607
608 let exit_code = {
611 let mut lock = guard.inner.lock().unwrap_or_else(|e| e.into_inner());
612 let status = lock.as_mut().and_then(|c| c.wait().ok());
613 let _ = lock.take(); status.and_then(|s| {
615 #[cfg(unix)]
616 {
617 use std::os::unix::process::ExitStatusExt;
618 s.code().or_else(|| s.signal().map(|sig| 128 + sig))
619 }
620 #[cfg(not(unix))]
621 {
622 s.code()
623 }
624 })
625 };
626
627 let _ = tx.send(SnippetEvent::HostDone {
628 run_id,
629 alias: alias.to_string(),
630 stdout: sanitize_output(&stdout_text),
631 stderr: sanitize_output(&stderr_text),
632 exit_code,
633 });
634
635 Some(guard)
636 }
637 Err(e) => {
638 let _ = tx.send(SnippetEvent::HostDone {
639 run_id,
640 alias: alias.to_string(),
641 stdout: String::new(),
642 stderr: format!("Failed to launch ssh: {}", e),
643 exit_code: None,
644 });
645 None
646 }
647 }
648}
649
/// Run `command` over SSH on every host in `askpass_map` from a background
/// "snippet-coordinator" thread, reporting results through `tx`.
///
/// Events sent on `tx`:
/// - one `SnippetEvent::HostDone` per launched host (sent by `execute_host`);
/// - one `SnippetEvent::Progress` per launched host;
/// - a single `SnippetEvent::AllDone` when the coordinator finishes, including
///   after cancellation.
///
/// Parameters:
/// - `run_id`: copied into every event.
/// - `askpass_map`: `(alias, askpass)` pairs; `askpass` is forwarded to
///   `execute_host`.
/// - `tunnel_aliases`: aliases whose hosts are run with
///   `has_active_tunnel = true`.
/// - `cancel`: checked before each host is launched, and while waiting for a
///   free slot in parallel mode. Setting it stops new hosts from starting.
///   This function does not kill hosts that are already running.
/// - `parallel`: when true and there is more than one host, up to 20 hosts
///   run at once on worker threads. Otherwise hosts run one after another on
///   the coordinator thread.
///
/// # Panics
/// Panics if the coordinator thread cannot be spawned.
// All arguments are independent inputs threaded through to `execute_host`.
#[allow(clippy::too_many_arguments)]
pub fn spawn_snippet_execution(
    run_id: u64,
    askpass_map: Vec<(String, Option<String>)>,
    config_path: PathBuf,
    command: String,
    bw_session: Option<String>,
    tunnel_aliases: std::collections::HashSet<String>,
    cancel: std::sync::Arc<std::sync::atomic::AtomicBool>,
    tx: std::sync::mpsc::Sender<SnippetEvent>,
    parallel: bool,
) {
    let total = askpass_map.len();
    let max_concurrent: usize = 20;

    std::thread::Builder::new()
        .name("snippet-coordinator".into())
        .spawn(move || {
            // NOTE(review): `execute_host` waits for the child and takes it out
            // of the guard before returning it, so these guards no longer own a
            // process. Holding them does not make cancellation kill anything.
            let guards: std::sync::Arc<std::sync::Mutex<Vec<std::sync::Arc<ChildGuard>>>> =
                std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));

            if parallel && total > 1 {
                // Channel used as a counting semaphore: each queued `()` is one
                // free worker slot.
                let (slot_tx, slot_rx) = std::sync::mpsc::channel::<()>();
                for _ in 0..max_concurrent.min(total) {
                    let _ = slot_tx.send(());
                }

                let completed = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
                let mut worker_handles = Vec::new();

                for (alias, askpass) in askpass_map {
                    if cancel.load(std::sync::atomic::Ordering::Relaxed) {
                        break;
                    }

                    // Wait for a free slot, waking every 100 ms to notice cancellation.
                    loop {
                        match slot_rx.recv_timeout(std::time::Duration::from_millis(100)) {
                            Ok(()) => break,
                            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                                if cancel.load(std::sync::atomic::Ordering::Relaxed) {
                                    break;
                                }
                            }
                            // Disconnected: `slot_tx` is still held here, but
                            // stop waiting rather than spin.
                            Err(_) => break,
                        }
                    }

                    // The wait above may have ended because of cancellation.
                    if cancel.load(std::sync::atomic::Ordering::Relaxed) {
                        break;
                    }

                    let config_path = config_path.clone();
                    let command = command.clone();
                    let bw_session = bw_session.clone();
                    let has_tunnel = tunnel_aliases.contains(&alias);
                    let tx = tx.clone();
                    let slot_tx = slot_tx.clone();
                    let guards = guards.clone();
                    let completed = completed.clone();
                    let total = total;

                    let handle = std::thread::spawn(move || {
                        // Returns the slot when the worker ends, including when
                        // it unwinds from a panic.
                        struct SlotRelease(Option<std::sync::mpsc::Sender<()>>);
                        impl Drop for SlotRelease {
                            fn drop(&mut self) {
                                if let Some(tx) = self.0.take() {
                                    let _ = tx.send(());
                                }
                            }
                        }
                        let _slot = SlotRelease(Some(slot_tx));

                        let guard = execute_host(
                            run_id,
                            &alias,
                            &config_path,
                            &command,
                            askpass.as_deref(),
                            bw_session.as_deref(),
                            has_tunnel,
                            &tx,
                        );

                        if let Some(g) = guard {
                            guards.lock().unwrap_or_else(|e| e.into_inner()).push(g);
                        }

                        // Shared counter, so `completed` counts hosts across all workers.
                        let c = completed.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
                        let _ = tx.send(SnippetEvent::Progress {
                            run_id,
                            completed: c,
                            total,
                        });
                    });
                    worker_handles.push(handle);
                }

                // Wait for every launched worker before announcing AllDone.
                for handle in worker_handles {
                    let _ = handle.join();
                }
            } else {
                // Sequential mode: hosts run one at a time on this thread.
                for (i, (alias, askpass)) in askpass_map.into_iter().enumerate() {
                    if cancel.load(std::sync::atomic::Ordering::Relaxed) {
                        break;
                    }

                    let has_tunnel = tunnel_aliases.contains(&alias);
                    let guard = execute_host(
                        run_id,
                        &alias,
                        &config_path,
                        &command,
                        askpass.as_deref(),
                        bw_session.as_deref(),
                        has_tunnel,
                        &tx,
                    );

                    if let Some(g) = guard {
                        guards.lock().unwrap_or_else(|e| e.into_inner()).push(g);
                    }

                    let _ = tx.send(SnippetEvent::Progress {
                        run_id,
                        completed: i + 1,
                        total,
                    });
                }
            }

            let _ = tx.send(SnippetEvent::AllDone { run_id });
        })
        .expect("failed to spawn snippet coordinator");
}
794
795pub fn run_snippet(
800 alias: &str,
801 config_path: &Path,
802 command: &str,
803 askpass: Option<&str>,
804 bw_session: Option<&str>,
805 capture: bool,
806 has_active_tunnel: bool,
807) -> anyhow::Result<SnippetResult> {
808 let mut cmd = base_ssh_command(
809 alias,
810 config_path,
811 command,
812 askpass,
813 bw_session,
814 has_active_tunnel,
815 );
816 cmd.stdin(Stdio::inherit());
817
818 if capture {
819 cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
820 } else {
821 cmd.stdout(Stdio::inherit()).stderr(Stdio::inherit());
822 }
823
824 if capture {
825 let output = cmd
826 .output()
827 .map_err(|e| anyhow::anyhow!("Failed to run ssh for '{}': {}", alias, e))?;
828
829 Ok(SnippetResult {
830 status: output.status,
831 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
832 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
833 })
834 } else {
835 let status = cmd
836 .status()
837 .map_err(|e| anyhow::anyhow!("Failed to run ssh for '{}': {}", alias, e))?;
838
839 Ok(SnippetResult {
840 status,
841 stdout: String::new(),
842 stderr: String::new(),
843 })
844 }
845}
846
// Unit tests for this module live in `snippet_tests.rs`.
#[cfg(test)]
#[path = "snippet_tests.rs"]
mod tests;