1use std::collections::{BTreeMap, HashMap};
18use std::path::PathBuf;
19use std::sync::Arc;
20use std::time::Duration;
21
22use serde::{Deserialize, Serialize};
23use typed_builder::TypedBuilder;
24
25use tokio_util::sync::CancellationToken;
26
27use crate::callback::MessageCallback;
28use crate::hooks::HookMatcher;
29use crate::mcp::McpServers;
30use crate::permissions::CanUseToolCallback;
31
/// `input_format` value selecting streamed JSON input.
///
/// When `ClientConfig::input_format` equals this value, `to_cli_args` does not
/// put `--print` or the prompt on the command line, and `validate` permits
/// `init_stdin_message`.
pub const INPUT_FORMAT_STREAM_JSON: &str = "stream-json";

/// `output_format` value selecting streamed JSON output.
///
/// When `ClientConfig::output_format` equals this value (the builder default),
/// `to_cli_args` always includes `--verbose`.
pub const OUTPUT_FORMAT_STREAM_JSON: &str = "stream-json";
46
/// Configuration for a CLI session.
///
/// Holds everything needed to build the command line
/// ([`ClientConfig::to_cli_args`]) and the process environment
/// ([`ClientConfig::to_env`]), plus callbacks and timeouts that are consumed
/// elsewhere in the crate. Build with `ClientConfig::builder()`; only `prompt`
/// is required, every other field has a default.
#[derive(TypedBuilder)]
pub struct ClientConfig {
    /// Prompt text. Placed on the command line in print mode; not emitted at
    /// all when `input_format` is `"stream-json"`.
    #[builder(setter(into))]
    pub prompt: String,

    /// Optional path to the CLI executable. Not read by this module.
    #[builder(default, setter(strip_option))]
    pub cli_path: Option<PathBuf>,

    /// Working directory for the CLI process. `validate` checks that it exists
    /// and is a directory.
    #[builder(default, setter(strip_option))]
    pub cwd: Option<PathBuf>,

    /// Model name, emitted as `--model`.
    #[builder(default, setter(strip_option, into))]
    pub model: Option<String>,

    /// Fallback model name, emitted as `--fallback-model`.
    #[builder(default, setter(strip_option, into))]
    pub fallback_model: Option<String>,

    /// System prompt; see [`SystemPrompt`] for how each variant is emitted.
    #[builder(default, setter(strip_option))]
    pub system_prompt: Option<SystemPrompt>,

    /// Turn limit, emitted as `--max-turns`.
    #[builder(default, setter(strip_option))]
    pub max_turns: Option<u32>,

    /// Spending limit in US dollars, emitted as `--max-budget-usd`.
    #[builder(default, setter(strip_option))]
    pub max_budget_usd: Option<f64>,

    /// Thinking-token limit, emitted as `--max-thinking-tokens`.
    #[builder(default, setter(strip_option))]
    pub max_thinking_tokens: Option<u32>,

    /// Tools to allow; each entry is emitted as its own `--allowedTools <tool>` pair.
    #[builder(default)]
    pub allowed_tools: Vec<String>,

    /// Tools to deny; each entry is emitted as its own `--disallowedTools <tool>` pair.
    #[builder(default)]
    pub disallowed_tools: Vec<String>,

    /// Permission mode. Emitted as `--permission-mode` only when it is not
    /// [`PermissionMode::Default`].
    #[builder(default)]
    pub permission_mode: PermissionMode,

    /// Tool-permission callback. When set, `--permission-prompt-tool stdio` is
    /// emitted and `CLAUDE_CODE_SDK_CONTROL_PORT=stdin` is added to the environment.
    #[builder(default, setter(strip_option))]
    pub can_use_tool: Option<CanUseToolCallback>,

    /// Value for `--resume` (presumably a session ID — NOTE(review): confirm).
    #[builder(default, setter(strip_option, into))]
    pub resume: Option<String>,

    /// Hook matchers. A non-empty list adds `CLAUDE_CODE_SDK_CONTROL_PORT=stdin`
    /// to the environment.
    #[builder(default)]
    pub hooks: Vec<HookMatcher>,

    /// MCP server definitions; when non-empty they are serialized to JSON and
    /// emitted as `--mcp-servers <json>`.
    #[builder(default)]
    pub mcp_servers: McpServers,

    /// Optional message callback. NOTE(review): invoked outside this module —
    /// confirm its exact semantics there.
    #[builder(default, setter(strip_option))]
    pub message_callback: Option<MessageCallback>,

    /// Extra environment variables for the CLI process. These override the
    /// SDK defaults that `to_env` inserts (e.g. `TERM`).
    #[builder(default)]
    pub env: HashMap<String, String>,

    /// Emit `--verbose`.
    #[builder(default)]
    pub verbose: bool,

    /// Value for `--output-format`. Defaults to `"stream-json"`, which also
    /// causes `--verbose` to be emitted.
    #[builder(default_code = r#""stream-json".into()"#, setter(into))]
    pub output_format: String,

    /// Additional flags appended to the command line. Each key becomes
    /// `--<key>`, followed by the value when it is `Some` (use `None` for
    /// boolean flags). A `BTreeMap` is used, so flags are emitted in key order.
    #[builder(default)]
    pub extra_args: BTreeMap<String, Option<String>>,

    /// Connection timeout; defaults to 30 seconds. NOTE(review): presumably
    /// `None` disables it — confirm in the transport code.
    #[builder(default_code = "Some(Duration::from_secs(30))")]
    pub connect_timeout: Option<Duration>,

    /// Shutdown timeout; defaults to 10 seconds. NOTE(review): presumably
    /// `None` disables it — confirm in the transport code.
    #[builder(default_code = "Some(Duration::from_secs(10))")]
    pub close_timeout: Option<Duration>,

    /// Defaults to `true`. NOTE(review): presumably controls whether stdin is
    /// closed once the connection is established — confirm in the transport code.
    #[builder(default_code = "true")]
    pub end_input_on_connect: bool,

    /// Read timeout; no timeout (`None`) by default.
    #[builder(default)]
    pub read_timeout: Option<Duration>,

    /// Default hook timeout; 30 seconds. NOTE(review): presumably applied when
    /// a hook matcher does not specify its own — confirm.
    #[builder(default_code = "Duration::from_secs(30)")]
    pub default_hook_timeout: Duration,

    /// Timeout for the CLI version check; defaults to 5 seconds.
    #[builder(default_code = "Some(Duration::from_secs(5))")]
    pub version_check_timeout: Option<Duration>,

    /// Timeout for control requests; defaults to 30 seconds. NOTE(review):
    /// consumed outside this module — confirm scope.
    #[builder(default_code = "Duration::from_secs(30)")]
    pub control_request_timeout: Duration,

    /// Optional cancellation token. NOTE(review): cancellation is handled
    /// outside this module.
    #[builder(default, setter(strip_option))]
    pub cancellation_token: Option<CancellationToken>,

    /// Optional callback that receives CLI stderr output as `String`s.
    /// NOTE(review): presumably one call per line — confirm.
    #[builder(default, setter(strip_option))]
    pub stderr_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,

    /// Value for `--input-format`. With `"stream-json"`, `--print` and the
    /// prompt are not placed on the command line. Must not be combined with an
    /// `input-format` entry in `extra_args` (rejected by `validate`).
    #[builder(default, setter(strip_option, into))]
    pub input_format: Option<String>,

    /// Initial stdin message. `validate` requires it to be valid JSON and
    /// requires `input_format` to be `"stream-json"`.
    #[builder(default, setter(strip_option, into))]
    pub init_stdin_message: Option<String>,
}
272
273impl std::fmt::Debug for ClientConfig {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("ClientConfig")
276 .field("prompt", &self.prompt)
277 .field("cli_path", &self.cli_path)
278 .field("cwd", &self.cwd)
279 .field("model", &self.model)
280 .field("permission_mode", &self.permission_mode)
281 .field("max_turns", &self.max_turns)
282 .field("max_budget_usd", &self.max_budget_usd)
283 .field("verbose", &self.verbose)
284 .field("output_format", &self.output_format)
285 .field("connect_timeout", &self.connect_timeout)
286 .field("close_timeout", &self.close_timeout)
287 .field("read_timeout", &self.read_timeout)
288 .field("default_hook_timeout", &self.default_hook_timeout)
289 .field("version_check_timeout", &self.version_check_timeout)
290 .field("control_request_timeout", &self.control_request_timeout)
291 .finish_non_exhaustive()
292 }
293}
294
/// Permission mode passed to the CLI via `--permission-mode`.
///
/// Serde uses snake_case names (e.g. `"accept_edits"`), which differ from the
/// camelCase spellings the CLI flag takes; use [`PermissionMode::as_cli_flag`]
/// for the command-line form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PermissionMode {
    /// The default mode; `to_cli_args` emits no `--permission-mode` flag for it.
    #[default]
    Default,
    /// CLI spelling `acceptEdits`.
    AcceptEdits,
    /// CLI spelling `plan`.
    Plan,
    /// CLI spelling `bypassPermissions`. NOTE(review): presumably skips
    /// permission prompts entirely — confirm against the CLI documentation.
    BypassPermissions,
}
311
312impl PermissionMode {
313 #[must_use]
315 pub fn as_cli_flag(&self) -> &'static str {
316 match self {
317 Self::Default => "default",
318 Self::AcceptEdits => "acceptEdits",
319 Self::Plan => "plan",
320 Self::BypassPermissions => "bypassPermissions",
321 }
322 }
323}
324
/// System prompt configuration.
///
/// Serialized with an internal `"type"` tag whose value is `"text"` or
/// `"preset"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SystemPrompt {
    /// A literal system prompt, emitted as `--system-prompt <text>`.
    Text {
        /// The full system prompt text.
        text: String,
    },
    /// A named preset, emitted as `--system-prompt-preset <preset>`.
    Preset {
        /// NOTE(review): not used when `ClientConfig::to_cli_args` builds
        /// arguments — confirm its intended purpose.
        kind: String,
        /// Preset name passed to the CLI.
        preset: String,
        /// Optional extra text, emitted as `--append-system-prompt`; omitted
        /// from serialized output when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        append: Option<String>,
    },
}
347
348impl SystemPrompt {
349 #[must_use]
351 pub fn text(s: impl Into<String>) -> Self {
352 Self::Text { text: s.into() }
353 }
354
355 #[must_use]
357 pub fn preset(kind: impl Into<String>, preset: impl Into<String>) -> Self {
358 Self::Preset {
359 kind: kind.into(),
360 preset: preset.into(),
361 append: None,
362 }
363 }
364}
365
366impl ClientConfig {
369 pub fn validate(&self) -> crate::errors::Result<()> {
376 if let Some(ref cwd) = self.cwd {
377 if !cwd.exists() {
378 return Err(crate::errors::Error::Config(format!(
379 "working directory does not exist: {}",
380 cwd.display()
381 )));
382 }
383 if !cwd.is_dir() {
384 return Err(crate::errors::Error::Config(format!(
385 "working directory is not a directory: {}",
386 cwd.display()
387 )));
388 }
389 }
390
391 if let Some(ref msg) = self.init_stdin_message {
392 serde_json::from_str::<serde_json::Value>(msg).map_err(|e| {
393 crate::errors::Error::Config(format!("init_stdin_message is not valid JSON: {e}"))
394 })?;
395 }
396
397 if self.init_stdin_message.is_some()
398 && self.input_format.as_deref() != Some(INPUT_FORMAT_STREAM_JSON)
399 {
400 return Err(crate::errors::Error::Config(
401 "init_stdin_message requires input_format = \"stream-json\"".into(),
402 ));
403 }
404
405 if self.input_format.is_some() && self.extra_args.contains_key("input-format") {
406 return Err(crate::errors::Error::Config(
407 "input_format and extra_args[\"input-format\"] are mutually exclusive; use input_format".into(),
408 ));
409 }
410
411 Ok(())
412 }
413
414 #[must_use]
418 pub fn to_cli_args(&self) -> Vec<String> {
419 let mut args = vec!["--output-format".into(), self.output_format.clone()];
420
421 let uses_stream_input = self.input_format.as_deref() == Some(INPUT_FORMAT_STREAM_JSON);
425
426 if !uses_stream_input {
427 args.push("--print".into());
428 args.push(self.prompt.clone());
429 }
430
431 if let Some(ref fmt) = self.input_format {
432 args.push("--input-format".into());
433 args.push(fmt.clone());
434 }
435
436 if self.output_format == OUTPUT_FORMAT_STREAM_JSON && !self.verbose {
438 args.push("--verbose".into());
439 }
440
441 if let Some(model) = &self.model {
442 args.push("--model".into());
443 args.push(model.clone());
444 }
445
446 if let Some(fallback) = &self.fallback_model {
447 args.push("--fallback-model".into());
448 args.push(fallback.clone());
449 }
450
451 if let Some(turns) = self.max_turns {
452 args.push("--max-turns".into());
453 args.push(turns.to_string());
454 }
455
456 if let Some(budget) = self.max_budget_usd {
457 args.push("--max-budget-usd".into());
458 args.push(budget.to_string());
459 }
460
461 if let Some(thinking) = self.max_thinking_tokens {
462 args.push("--max-thinking-tokens".into());
463 args.push(thinking.to_string());
464 }
465
466 if self.permission_mode != PermissionMode::Default {
467 args.push("--permission-mode".into());
468 args.push(self.permission_mode.as_cli_flag().into());
469 }
470
471 if self.can_use_tool.is_some() {
476 args.push("--permission-prompt-tool".into());
477 args.push("stdio".into());
478 }
479
480 if let Some(resume) = &self.resume {
481 args.push("--resume".into());
482 args.push(resume.clone());
483 }
484
485 if self.verbose {
486 args.push("--verbose".into());
487 }
488
489 for tool in &self.allowed_tools {
490 args.push("--allowedTools".into());
491 args.push(tool.clone());
492 }
493
494 for tool in &self.disallowed_tools {
495 args.push("--disallowedTools".into());
496 args.push(tool.clone());
497 }
498
499 if !self.mcp_servers.is_empty() {
500 let json = serde_json::to_string(&self.mcp_servers)
501 .expect("McpServers serialization is infallible");
502 args.push("--mcp-servers".into());
503 args.push(json);
504 }
505
506 if let Some(prompt) = &self.system_prompt {
507 match prompt {
508 SystemPrompt::Text { text } => {
509 args.push("--system-prompt".into());
510 args.push(text.clone());
511 }
512 SystemPrompt::Preset { preset, append, .. } => {
513 args.push("--system-prompt-preset".into());
514 args.push(preset.clone());
515 if let Some(append_text) = append {
516 args.push("--append-system-prompt".into());
517 args.push(append_text.clone());
518 }
519 }
520 }
521 }
522
523 for (key, value) in &self.extra_args {
525 args.push(format!("--{key}"));
526 if let Some(v) = value {
527 args.push(v.clone());
528 }
529 }
530
531 args
532 }
533
534 #[must_use]
547 pub fn to_env(&self) -> HashMap<String, String> {
548 let mut env = HashMap::new();
549
550 env.insert(
552 "CLAUDE_CODE_SDK_ORIGINATOR".into(),
553 "claude_cli_sdk_rs".into(),
554 );
555 env.insert("TERM".into(), "dumb".into());
556
557 env.extend(self.env.clone());
559
560 if self.can_use_tool.is_some() || !self.hooks.is_empty() {
562 env.insert("CLAUDE_CODE_SDK_CONTROL_PORT".into(), "stdin".into());
563 }
564
565 env
566 }
567}
568
// Unit tests for builder defaults, CLI argument construction, environment
// construction, serde round-trips and `validate`. Argument tests check
// membership/adjacency rather than absolute positions, so they do not depend
// on the exact ordering of emitted flags.
#[cfg(test)]
mod tests {
    use super::*;

    // Only `prompt` is required; the output format and permission mode fall
    // back to their builder defaults.
    #[test]
    fn builder_minimal() {
        let config = ClientConfig::builder().prompt("hello").build();
        assert_eq!(config.prompt, "hello");
        assert_eq!(config.output_format, "stream-json");
        assert_eq!(config.permission_mode, PermissionMode::Default);
    }

    #[test]
    fn builder_full() {
        let config = ClientConfig::builder()
            .prompt("test prompt")
            .model("claude-opus-4-5")
            .max_turns(5_u32)
            .max_budget_usd(1.0_f64)
            .permission_mode(PermissionMode::AcceptEdits)
            .verbose(true)
            .build();

        assert_eq!(config.model.as_deref(), Some("claude-opus-4-5"));
        assert_eq!(config.max_turns, Some(5));
        assert_eq!(config.max_budget_usd, Some(1.0));
        assert_eq!(config.permission_mode, PermissionMode::AcceptEdits);
        assert!(config.verbose);
    }

    // Print mode: the output format, `--print` and the prompt are all present.
    #[test]
    fn to_cli_args_minimal() {
        let config = ClientConfig::builder().prompt("hello").build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--output-format".into()));
        assert!(args.contains(&"stream-json".into()));
        assert!(args.contains(&"--print".into()));
        assert!(args.contains(&"hello".into()));
    }

    #[test]
    fn to_cli_args_with_model_and_turns() {
        let config = ClientConfig::builder()
            .prompt("test")
            .model("claude-sonnet-4-5")
            .max_turns(10_u32)
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--model".into()));
        assert!(args.contains(&"claude-sonnet-4-5".into()));
        assert!(args.contains(&"--max-turns".into()));
        assert!(args.contains(&"10".into()));
    }

    // Non-default modes use the camelCase CLI spelling, not the serde name.
    #[test]
    fn to_cli_args_with_permission_mode() {
        let config = ClientConfig::builder()
            .prompt("test")
            .permission_mode(PermissionMode::BypassPermissions)
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--permission-mode".into()));
        assert!(args.contains(&"bypassPermissions".into()));
    }

    #[test]
    fn to_cli_args_default_permission_mode_not_included() {
        let config = ClientConfig::builder().prompt("test").build();
        let args = config.to_cli_args();
        assert!(!args.contains(&"--permission-mode".into()));
    }

    #[test]
    fn to_cli_args_with_system_prompt_text() {
        let config = ClientConfig::builder()
            .prompt("test")
            .system_prompt(SystemPrompt::text("You are a helpful assistant"))
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--system-prompt".into()));
        assert!(args.contains(&"You are a helpful assistant".into()));
    }

    // Only the flag's presence is checked; the JSON payload shape belongs to
    // the `mcp` module.
    #[test]
    fn to_cli_args_with_mcp_servers() {
        use crate::mcp::McpServerConfig;

        let mut servers = McpServers::new();
        servers.insert(
            "fs".into(),
            McpServerConfig::new("npx").with_args(["-y", "mcp-fs"]),
        );

        let config = ClientConfig::builder()
            .prompt("test")
            .mcp_servers(servers)
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--mcp-servers".into()));
    }

    #[test]
    fn to_env_without_callbacks() {
        let config = ClientConfig::builder().prompt("test").build();
        let env = config.to_env();
        assert!(!env.contains_key("CLAUDE_CODE_SDK_CONTROL_PORT"));
    }

    #[test]
    fn to_env_includes_originator_and_headless_defaults() {
        let config = ClientConfig::builder().prompt("test").build();
        let env = config.to_env();
        assert_eq!(
            env.get("CLAUDE_CODE_SDK_ORIGINATOR"),
            Some(&"claude_cli_sdk_rs".into())
        );
        assert!(!env.contains_key("CI"), "CI must not be set by default");
        assert_eq!(env.get("TERM"), Some(&"dumb".into()));
    }

    // User env overrides a default (TERM) without dropping the other defaults.
    #[test]
    fn to_env_user_env_overrides_defaults() {
        let config = ClientConfig::builder()
            .prompt("test")
            .env(HashMap::from([("TERM".into(), "xterm-256color".into())]))
            .build();
        let env = config.to_env();
        assert_eq!(env.get("TERM"), Some(&"xterm-256color".into()));
        assert_eq!(
            env.get("CLAUDE_CODE_SDK_ORIGINATOR"),
            Some(&"claude_cli_sdk_rs".into())
        );
    }

    #[test]
    fn to_env_with_hooks_enables_control_port() {
        use crate::hooks::{HookCallback, HookEvent, HookMatcher, HookOutput};
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let config = ClientConfig::builder()
            .prompt("test")
            .hooks(vec![HookMatcher::new(HookEvent::PreToolUse, cb)])
            .build();
        let env = config.to_env();
        assert_eq!(
            env.get("CLAUDE_CODE_SDK_CONTROL_PORT"),
            Some(&"stdin".into())
        );
    }

    #[test]
    fn permission_mode_serde_round_trip() {
        let modes = [
            PermissionMode::Default,
            PermissionMode::AcceptEdits,
            PermissionMode::Plan,
            PermissionMode::BypassPermissions,
        ];
        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let decoded: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, decoded);
        }
    }

    #[test]
    fn system_prompt_text_round_trip() {
        let sp = SystemPrompt::text("You are helpful");
        let json = serde_json::to_string(&sp).unwrap();
        let decoded: SystemPrompt = serde_json::from_str(&json).unwrap();
        assert_eq!(sp, decoded);
    }

    // Exercises the optional `append` field, which is skipped when `None`.
    #[test]
    fn system_prompt_preset_round_trip() {
        let sp = SystemPrompt::Preset {
            kind: "custom".into(),
            preset: "coding".into(),
            append: Some("Also be concise.".into()),
        };
        let json = serde_json::to_string(&sp).unwrap();
        let decoded: SystemPrompt = serde_json::from_str(&json).unwrap();
        assert_eq!(sp, decoded);
    }

    #[test]
    fn debug_does_not_panic() {
        let config = ClientConfig::builder().prompt("test").build();
        let _ = format!("{config:?}");
    }

    // Each tool gets its own `--allowedTools` flag rather than a joined list.
    #[test]
    fn to_cli_args_with_allowed_tools() {
        let config = ClientConfig::builder()
            .prompt("test")
            .allowed_tools(vec!["bash".into(), "read_file".into()])
            .build();
        let args = config.to_cli_args();
        let allowed_count = args.iter().filter(|a| *a == "--allowedTools").count();
        assert_eq!(allowed_count, 2);
    }

    // A `None` value renders as a bare boolean flag.
    #[test]
    fn to_cli_args_with_extra_args_boolean_flag() {
        let config = ClientConfig::builder()
            .prompt("test")
            .extra_args(BTreeMap::from([("replay-user-messages".into(), None)]))
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--replay-user-messages".into()));
    }

    // A `Some` value immediately follows its flag.
    #[test]
    fn to_cli_args_with_extra_args_valued_flag() {
        let config = ClientConfig::builder()
            .prompt("test")
            .extra_args(BTreeMap::from([(
                "context-window".into(),
                Some("200000".into()),
            )]))
            .build();
        let args = config.to_cli_args();
        let idx = args.iter().position(|a| a == "--context-window").unwrap();
        assert_eq!(args[idx + 1], "200000");
    }

    #[test]
    fn builder_timeout_defaults() {
        let config = ClientConfig::builder().prompt("test").build();
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(30)));
        assert_eq!(config.close_timeout, Some(Duration::from_secs(10)));
        assert_eq!(config.read_timeout, None);
        assert_eq!(config.default_hook_timeout, Duration::from_secs(30));
        assert_eq!(config.version_check_timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn builder_custom_timeouts() {
        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_secs(60)))
            .close_timeout(Some(Duration::from_secs(20)))
            .read_timeout(Some(Duration::from_secs(120)))
            .default_hook_timeout(Duration::from_secs(10))
            .version_check_timeout(Some(Duration::from_secs(15)))
            .build();
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.close_timeout, Some(Duration::from_secs(20)));
        assert_eq!(config.read_timeout, Some(Duration::from_secs(120)));
        assert_eq!(config.default_hook_timeout, Duration::from_secs(10));
        assert_eq!(config.version_check_timeout, Some(Duration::from_secs(15)));
    }

    // Timeout setters take `Option` (no `strip_option`), so `None` can be
    // passed to override the `Some` default.
    #[test]
    fn builder_disable_connect_timeout() {
        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(None::<Duration>)
            .build();
        assert_eq!(config.connect_timeout, None);
    }

    #[test]
    fn builder_cancellation_token() {
        let token = CancellationToken::new();
        let config = ClientConfig::builder()
            .prompt("test")
            .cancellation_token(token.clone())
            .build();
        assert!(config.cancellation_token.is_some());
    }

    #[test]
    fn builder_cancellation_token_default_is_none() {
        let config = ClientConfig::builder().prompt("test").build();
        assert!(config.cancellation_token.is_none());
    }

    #[test]
    fn to_cli_args_with_resume() {
        let config = ClientConfig::builder()
            .prompt("test")
            .resume("session-123")
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--resume".into()));
        assert!(args.contains(&"session-123".into()));
    }

    // Stream-json input: no `--print`, and `--input-format` carries the value.
    #[test]
    fn to_cli_args_stream_input_format_omits_print() {
        let config = ClientConfig::builder()
            .prompt("ignored")
            .input_format(INPUT_FORMAT_STREAM_JSON)
            .build();
        let args = config.to_cli_args();
        assert!(
            !args.contains(&"--print".into()),
            "--print must be absent in stream-json input mode"
        );
        assert!(args.contains(&"--input-format".into()));
        let idx = args.iter().position(|a| a == "--input-format").unwrap();
        assert_eq!(args[idx + 1], INPUT_FORMAT_STREAM_JSON);
    }

    // Any other input format is passed through verbatim.
    #[test]
    fn to_cli_args_input_format_emitted() {
        let config = ClientConfig::builder()
            .prompt("test")
            .input_format("custom-format")
            .build();
        let args = config.to_cli_args();
        assert!(args.contains(&"--input-format".into()));
        let idx = args.iter().position(|a| a == "--input-format").unwrap();
        assert_eq!(args[idx + 1], "custom-format");
    }

    #[test]
    fn validate_init_stdin_message_valid_json() {
        let config = ClientConfig::builder()
            .prompt("ignored")
            .input_format(INPUT_FORMAT_STREAM_JSON)
            .init_stdin_message(r#"{"type":"user","message":{"role":"user","content":"hello"}}"#)
            .build();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn validate_init_stdin_message_invalid_json() {
        let config = ClientConfig::builder()
            .prompt("ignored")
            .input_format(INPUT_FORMAT_STREAM_JSON)
            .init_stdin_message("not valid json {")
            .build();
        let err = config.validate().unwrap_err();
        assert!(
            matches!(err, crate::errors::Error::Config(ref msg) if msg.contains("not valid JSON")),
            "expected Config error about JSON validity, got: {err:?}"
        );
    }

    // Valid JSON alone is not enough: stream-json input must also be selected.
    #[test]
    fn validate_init_stdin_message_without_input_format() {
        let config = ClientConfig::builder()
            .prompt("ignored")
            .init_stdin_message(r#"{"type":"user"}"#)
            .build();
        let err = config.validate().unwrap_err();
        assert!(
            matches!(err, crate::errors::Error::Config(ref msg) if msg.contains("input_format")),
            "expected Config error about missing input_format, got: {err:?}"
        );
    }
}