1use crate::{
7 errors::Result,
8 transport::InputMessage,
9 types::{ClaudeCodeOptions, Message, PermissionMode},
10};
11use futures::stream::Stream;
12use std::pin::Pin;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tracing::{debug, info, warn};
16
17pub enum QueryInput {
19 Text(String),
21 Stream(Pin<Box<dyn Stream<Item = InputMessage> + Send>>),
23}
24
25impl From<String> for QueryInput {
26 fn from(s: String) -> Self {
27 QueryInput::Text(s)
28 }
29}
30
31impl From<&str> for QueryInput {
32 fn from(s: &str) -> Self {
33 QueryInput::Text(s.to_string())
34 }
35}
36
37pub async fn query(
116 prompt: impl Into<QueryInput>,
117 options: Option<ClaudeCodeOptions>,
118) -> Result<impl Stream<Item = Result<Message>>> {
119 let options = options.unwrap_or_default();
120 let prompt = prompt.into();
121
122 unsafe {
124 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
125 }
126
127 match prompt {
128 QueryInput::Text(text) => {
129 query_print_mode(text, options).await
131 }
132 QueryInput::Stream(_stream) => {
133 Err(crate::SdkError::NotSupported {
136 feature: "Streaming input mode not yet implemented".into(),
137 })
138 }
139 }
140}
141
142#[allow(deprecated)]
144async fn query_print_mode(
145 prompt: String,
146 options: ClaudeCodeOptions,
147) -> Result<impl Stream<Item = Result<Message>>> {
148 use std::sync::Arc;
149 use tokio::io::{AsyncBufReadExt, BufReader};
150 use tokio::process::Command;
151 use tokio::sync::Mutex;
152
153 let cli_path = crate::transport::subprocess::find_claude_cli()?;
154 let mut cmd = Command::new(&cli_path);
155
156 cmd.arg("--output-format").arg("stream-json");
158 cmd.arg("--verbose");
159
160 #[allow(deprecated)]
162 if let Some(ref system_prompt) = options.system_prompt {
163 cmd.arg("--system-prompt").arg(system_prompt);
164 }
165
166 #[allow(deprecated)]
167 if let Some(ref append_prompt) = options.append_system_prompt {
168 cmd.arg("--append-system-prompt").arg(append_prompt);
169 }
170
171 if !options.allowed_tools.is_empty() {
172 cmd.arg("--allowedTools")
173 .arg(options.allowed_tools.join(","));
174 }
175
176 if let Some(max_turns) = options.max_turns {
177 cmd.arg("--max-turns").arg(max_turns.to_string());
178 }
179
180 if !options.disallowed_tools.is_empty() {
181 cmd.arg("--disallowedTools")
182 .arg(options.disallowed_tools.join(","));
183 }
184
185 if let Some(ref model) = options.model {
186 cmd.arg("--model").arg(model);
187 }
188
189 if let Some(ref tool_name) = options.permission_prompt_tool_name {
190 cmd.arg("--permission-prompt-tool").arg(tool_name);
191 }
192
193 match options.permission_mode {
194 PermissionMode::Default => {
195 cmd.arg("--permission-mode").arg("default");
196 }
197 PermissionMode::AcceptEdits => {
198 cmd.arg("--permission-mode").arg("acceptEdits");
199 }
200 PermissionMode::Plan => {
201 cmd.arg("--permission-mode").arg("plan");
202 }
203 PermissionMode::BypassPermissions => {
204 cmd.arg("--permission-mode").arg("bypassPermissions");
205 }
206 }
207
208 if options.continue_conversation {
209 cmd.arg("--continue");
210 }
211
212 if let Some(ref resume_id) = options.resume {
213 cmd.arg("--resume").arg(resume_id);
214 }
215
216 if !options.mcp_servers.is_empty() {
217 let mcp_config = serde_json::json!({
218 "mcpServers": options.mcp_servers
219 });
220 cmd.arg("--mcp-config").arg(mcp_config.to_string());
221 }
222
223 for (key, value) in &options.extra_args {
225 let flag = if key.starts_with("--") || key.starts_with("-") {
226 key.clone()
227 } else {
228 format!("--{key}")
229 };
230 cmd.arg(&flag);
231 if let Some(val) = value {
232 cmd.arg(val);
233 }
234 }
235
236 cmd.arg("--print").arg(&prompt);
238
239 cmd.stdout(std::process::Stdio::piped())
241 .stderr(std::process::Stdio::piped());
242
243 if let Some(max_tokens) = options.max_output_tokens {
246 let capped = max_tokens.clamp(1, 32000);
248 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", capped.to_string());
249 debug!("Setting max_output_tokens from option: {}", capped);
250 } else {
251 if let Ok(current_value) = std::env::var("CLAUDE_CODE_MAX_OUTPUT_TOKENS") {
253 if let Ok(tokens) = current_value.parse::<u32>() {
254 if tokens > 32000 {
255 warn!("CLAUDE_CODE_MAX_OUTPUT_TOKENS={} exceeds maximum safe value of 32000, overriding to 32000", tokens);
256 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", "32000");
257 }
258 } else {
259 warn!("Invalid CLAUDE_CODE_MAX_OUTPUT_TOKENS value: {}, setting to 8192", current_value);
260 cmd.env("CLAUDE_CODE_MAX_OUTPUT_TOKENS", "8192");
261 }
262 }
263 }
264
265 info!("Starting Claude CLI with --print mode");
266 debug!("Command: {:?}", cmd);
267
268 let mut child = cmd.spawn().map_err(crate::SdkError::ProcessError)?;
269
270 let stdout = child
271 .stdout
272 .take()
273 .ok_or_else(|| crate::SdkError::ConnectionError("Failed to get stdout".into()))?;
274 let stderr = child
275 .stderr
276 .take()
277 .ok_or_else(|| crate::SdkError::ConnectionError("Failed to get stderr".into()))?;
278
279 let child = Arc::new(Mutex::new(child));
281 let child_clone = Arc::clone(&child);
282
283 let (tx, rx) = mpsc::channel(100);
285
286 tokio::spawn(async move {
288 let reader = BufReader::new(stderr);
289 let mut lines = reader.lines();
290 while let Ok(Some(line)) = lines.next_line().await {
291 if !line.trim().is_empty() {
292 debug!("Claude stderr: {}", line);
293 }
294 }
295 });
296
297 let tx_cleanup = tx.clone();
299
300 tokio::spawn(async move {
302 let reader = BufReader::new(stdout);
303 let mut lines = reader.lines();
304
305 while let Ok(Some(line)) = lines.next_line().await {
306 if line.trim().is_empty() {
307 continue;
308 }
309
310 debug!("Claude output: {}", line);
311
312 match serde_json::from_str::<serde_json::Value>(&line) {
314 Ok(json) => {
315 match crate::message_parser::parse_message(json) {
316 Ok(Some(message)) => {
317 if tx.send(Ok(message)).await.is_err() {
318 break;
319 }
320 }
321 Ok(None) => {
322 }
324 Err(e) => {
325 if tx.send(Err(e)).await.is_err() {
326 break;
327 }
328 }
329 }
330 }
331 Err(e) => {
332 debug!("Failed to parse JSON: {} - Line: {}", e, line);
333 }
334 }
335 }
336
337 let mut child = child_clone.lock().await;
339 match child.wait().await {
340 Ok(status) => {
341 if !status.success() {
342 let _ = tx
343 .send(Err(crate::SdkError::ProcessExited {
344 code: status.code(),
345 }))
346 .await;
347 }
348 }
349 Err(e) => {
350 let _ = tx.send(Err(crate::SdkError::ProcessError(e))).await;
351 }
352 }
353 });
354
355 tokio::spawn(async move {
357 tx_cleanup.closed().await;
359
360 let mut child = child.lock().await;
362 match child.try_wait() {
363 Ok(Some(_)) => {
364 debug!("Claude CLI process already exited");
366 }
367 Ok(None) => {
368 info!("Killing Claude CLI process on stream drop");
370 if let Err(e) = child.kill().await {
371 warn!("Failed to kill Claude CLI process: {}", e);
372 } else {
373 let _ = child.wait().await;
375 debug!("Claude CLI process killed and cleaned up");
376 }
377 }
378 Err(e) => {
379 warn!("Failed to check process status: {}", e);
380 }
381 }
382 });
383
384 Ok(ReceiverStream::new(rx))
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_input_from_string() {
        // Convert an owned `String` so this exercises `From<String>`, not `From<&str>`.
        let input: QueryInput = String::from("Hello").into();
        match input {
            QueryInput::Text(s) => assert_eq!(s, "Hello"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_query_input_from_str() {
        let input: QueryInput = "World".into();
        match input {
            QueryInput::Text(s) => assert_eq!(s, "World"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_query_input_from_empty_str() {
        let input: QueryInput = "".into();
        match input {
            QueryInput::Text(s) => assert!(s.is_empty()),
            _ => panic!("Expected Text variant"),
        }
    }

    /// Checks only that bare, long-dashed and short-dashed keys survive in
    /// `ClaudeCodeOptions::extra_args`. The `--` prefixing itself is applied when the CLI
    /// command is built and is not exercised here.
    #[test]
    fn test_extra_args_formatting() {
        use std::collections::HashMap;

        let mut extra_args = HashMap::new();
        extra_args.insert("custom-flag".to_string(), Some("value".to_string()));
        extra_args.insert("--already-dashed".to_string(), None);
        extra_args.insert("-s".to_string(), Some("short".to_string()));

        let options = ClaudeCodeOptions {
            extra_args,
            ..Default::default()
        };

        assert_eq!(options.extra_args.len(), 3);
        assert!(options.extra_args.contains_key("custom-flag"));
        assert!(options.extra_args.contains_key("--already-dashed"));
        assert!(options.extra_args.contains_key("-s"));
    }
}