1use crate::{
7 errors::Result,
8 transport::InputMessage,
9 types::{ClaudeCodeOptions, Message, PermissionMode},
10};
11use futures::stream::Stream;
12use std::pin::Pin;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tracing::{debug, info, warn};
16
17pub enum QueryInput {
19 Text(String),
21 Stream(Pin<Box<dyn Stream<Item = InputMessage> + Send>>),
23}
24
25impl From<String> for QueryInput {
26 fn from(s: String) -> Self {
27 QueryInput::Text(s)
28 }
29}
30
31impl From<&str> for QueryInput {
32 fn from(s: &str) -> Self {
33 QueryInput::Text(s.to_string())
34 }
35}
36
37pub async fn query(
116 prompt: impl Into<QueryInput>,
117 options: Option<ClaudeCodeOptions>,
118) -> Result<impl Stream<Item = Result<Message>>> {
119 let options = options.unwrap_or_default();
120 let prompt = prompt.into();
121
122 unsafe {
124 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
125 }
126
127 match prompt {
128 QueryInput::Text(text) => {
129 query_print_mode(text, options).await
131 }
132 QueryInput::Stream(_stream) => {
133 Err(crate::SdkError::NotSupported {
136 feature: "Streaming input mode not yet implemented".into(),
137 })
138 }
139 }
140}
141
142#[allow(deprecated)]
144async fn query_print_mode(
145 prompt: String,
146 options: ClaudeCodeOptions,
147) -> Result<impl Stream<Item = Result<Message>>> {
148 use std::sync::Arc;
149 use tokio::io::{AsyncBufReadExt, BufReader};
150 use tokio::process::Command;
151 use tokio::sync::Mutex;
152
153 let cli_path = crate::transport::subprocess::find_claude_cli()?;
154 let mut cmd = Command::new(&cli_path);
155
156 cmd.arg("--output-format").arg("stream-json");
158 cmd.arg("--verbose");
159
160 if let Some(ref prompt_v2) = options.system_prompt_v2 {
164 match prompt_v2 {
165 crate::types::SystemPrompt::String(s) => {
166 cmd.arg("--system-prompt").arg(s);
167 }
168 crate::types::SystemPrompt::Preset { append, .. } => {
169 if let Some(append_text) = append {
170 cmd.arg("--append-system-prompt").arg(append_text);
171 }
172 }
173 }
174 } else {
175 #[allow(deprecated)]
176 match options.system_prompt.as_deref() {
177 Some(prompt) => {
178 cmd.arg("--system-prompt").arg(prompt);
179 }
180 None => {
181 cmd.arg("--system-prompt").arg("");
182 }
183 }
184
185 #[allow(deprecated)]
186 if let Some(ref append_prompt) = options.append_system_prompt {
187 cmd.arg("--append-system-prompt").arg(append_prompt);
188 }
189 }
190
191 if !options.allowed_tools.is_empty() {
192 cmd.arg("--allowedTools")
193 .arg(options.allowed_tools.join(","));
194 }
195
196 if let Some(max_turns) = options.max_turns {
197 cmd.arg("--max-turns").arg(max_turns.to_string());
198 }
199
200 if let Some(max_thinking_tokens) = options.max_thinking_tokens {
202 if max_thinking_tokens > 0 {
203 cmd.arg("--max-thinking-tokens")
204 .arg(max_thinking_tokens.to_string());
205 }
206 }
207
208 if !options.disallowed_tools.is_empty() {
209 cmd.arg("--disallowedTools")
210 .arg(options.disallowed_tools.join(","));
211 }
212
213 if let Some(ref model) = options.model {
214 cmd.arg("--model").arg(model);
215 }
216
217 if let Some(ref tool_name) = options.permission_prompt_tool_name {
218 cmd.arg("--permission-prompt-tool").arg(tool_name);
219 }
220
221 match options.permission_mode {
222 PermissionMode::Default => {
223 cmd.arg("--permission-mode").arg("default");
224 }
225 PermissionMode::AcceptEdits => {
226 cmd.arg("--permission-mode").arg("acceptEdits");
227 }
228 PermissionMode::Plan => {
229 cmd.arg("--permission-mode").arg("plan");
230 }
231 PermissionMode::BypassPermissions => {
232 cmd.arg("--permission-mode").arg("bypassPermissions");
233 }
234 }
235
236 if options.continue_conversation {
237 cmd.arg("--continue");
238 }
239
240 if let Some(ref resume_id) = options.resume {
241 cmd.arg("--resume").arg(resume_id);
242 }
243
244 if !options.mcp_servers.is_empty() {
245 let mcp_config = serde_json::json!({
246 "mcpServers": options.mcp_servers
247 });
248 cmd.arg("--mcp-config").arg(mcp_config.to_string());
249 }
250
251 for (key, value) in &options.extra_args {
253 let flag = if key.starts_with("--") || key.starts_with("-") {
254 key.clone()
255 } else {
256 format!("--{key}")
257 };
258 cmd.arg(&flag);
259 if let Some(val) = value {
260 cmd.arg(val);
261 }
262 }
263
264 cmd.arg("--print").arg("--").arg(&prompt);
266
267 cmd.stdout(std::process::Stdio::piped())
269 .stderr(std::process::Stdio::piped());
270
271 if let Some(max_tokens) = options.max_output_tokens {
274 let capped = max_tokens.clamp(1, 32000);
276 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", capped.to_string());
277 debug!("Setting max_output_tokens from option: {}", capped);
278 } else {
279 if let Ok(current_value) = std::env::var("CLAUDE_CODE_MAX_OUTPUT_TOKENS") {
281 if let Ok(tokens) = current_value.parse::<u32>() {
282 if tokens > 32000 {
283 warn!("CLAUDE_CODE_MAX_OUTPUT_TOKENS={} exceeds maximum safe value of 32000, overriding to 32000", tokens);
284 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", "32000");
285 }
286 } else {
287 warn!("Invalid CLAUDE_CODE_MAX_OUTPUT_TOKENS value: {}, setting to 8192", current_value);
288 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", "8192");
289 }
290 }
291 }
292
293 info!("Starting Claude CLI with --print mode");
294 debug!("Command: {:?}", cmd);
295
296 if let Some(user) = options.user.as_deref() {
297 crate::transport::subprocess::apply_process_user(&mut cmd, user)?;
298 }
299
300 let mut child = cmd.spawn().map_err(crate::SdkError::ProcessError)?;
301
302 let stdout = child
303 .stdout
304 .take()
305 .ok_or_else(|| crate::SdkError::ConnectionError("Failed to get stdout".into()))?;
306 let stderr = child
307 .stderr
308 .take()
309 .ok_or_else(|| crate::SdkError::ConnectionError("Failed to get stderr".into()))?;
310
311 let child = Arc::new(Mutex::new(child));
313 let child_clone = Arc::clone(&child);
314
315 let (tx, rx) = mpsc::channel(100);
317
318 tokio::spawn(async move {
320 let reader = BufReader::new(stderr);
321 let mut lines = reader.lines();
322 while let Ok(Some(line)) = lines.next_line().await {
323 if !line.trim().is_empty() {
324 debug!("Claude stderr: {}", line);
325 }
326 }
327 });
328
329 let tx_cleanup = tx.clone();
331
332 tokio::spawn(async move {
334 let reader = BufReader::new(stdout);
335 let mut lines = reader.lines();
336
337 while let Ok(Some(line)) = lines.next_line().await {
338 if line.trim().is_empty() {
339 continue;
340 }
341
342 debug!("Claude output: {}", line);
343
344 match serde_json::from_str::<serde_json::Value>(&line) {
346 Ok(json) => {
347 match crate::message_parser::parse_message(json) {
348 Ok(Some(message)) => {
349 if tx.send(Ok(message)).await.is_err() {
350 break;
351 }
352 }
353 Ok(None) => {
354 }
356 Err(e) => {
357 if tx.send(Err(e)).await.is_err() {
358 break;
359 }
360 }
361 }
362 }
363 Err(e) => {
364 debug!("Failed to parse JSON: {} - Line: {}", e, line);
365 }
366 }
367 }
368
369 let mut child = child_clone.lock().await;
371 match child.wait().await {
372 Ok(status) => {
373 if !status.success() {
374 let _ = tx
375 .send(Err(crate::SdkError::ProcessExited {
376 code: status.code(),
377 }))
378 .await;
379 }
380 }
381 Err(e) => {
382 let _ = tx.send(Err(crate::SdkError::ProcessError(e))).await;
383 }
384 }
385 });
386
387 tokio::spawn(async move {
389 tx_cleanup.closed().await;
391
392 let mut child = child.lock().await;
394 match child.try_wait() {
395 Ok(Some(_)) => {
396 debug!("Claude CLI process already exited");
398 }
399 Ok(None) => {
400 info!("Killing Claude CLI process on stream drop");
402 if let Err(e) = child.kill().await {
403 warn!("Failed to kill Claude CLI process: {}", e);
404 } else {
405 let _ = child.wait().await;
407 debug!("Claude CLI process killed and cleaned up");
408 }
409 }
410 Err(e) => {
411 warn!("Failed to check process status: {}", e);
412 }
413 }
414 });
415
416 Ok(ReceiverStream::new(rx))
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_input_from_string() {
        // Convert an owned `String` so this exercises `From<String>`;
        // a string literal would go through `From<&str>` instead.
        let input: QueryInput = String::from("Hello").into();
        match input {
            QueryInput::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_query_input_from_str() {
        let input: QueryInput = "World".into();
        match input {
            QueryInput::Text(s) => assert_eq!(s, "World"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_extra_args_formatting() {
        use std::collections::HashMap;

        // Covers the three key shapes the command builder distinguishes: no
        // leading dash, a long `--` flag, and a short `-` flag. This test only
        // checks that `ClaudeCodeOptions` keeps the entries verbatim. It does
        // not spawn the CLI or inspect the generated command line.
        let mut extra_args = HashMap::new();
        extra_args.insert("custom-flag".to_string(), Some("value".to_string()));
        extra_args.insert("--already-dashed".to_string(), None);
        extra_args.insert("-s".to_string(), Some("short".to_string()));

        let options = ClaudeCodeOptions {
            extra_args,
            ..Default::default()
        };

        assert_eq!(options.extra_args.len(), 3);
        assert!(options.extra_args.contains_key("custom-flag"));
        assert!(options.extra_args.contains_key("--already-dashed"));
        assert!(options.extra_args.contains_key("-s"));
    }
}