1#![allow(dead_code)]
3
4use crate::compact::{
5 self, get_auto_compact_threshold, get_compact_prompt, get_effective_context_window_size,
6};
7use crate::services::compact::microcompact::truncate_tool_result_content;
8use crate::error::AgentError;
9use crate::hooks::{HookInput, HookRegistry};
10use crate::tools::orchestration::{self, ToolMessageUpdate};
11use crate::types::*;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14
/// Removes every `<think>…</think>` span from `content` and trims the result.
///
/// The scan is a two-state machine: an opening tag switches to "hidden",
/// a closing tag switches back to "visible", and only characters seen while
/// visible are kept. An unclosed `<think>` therefore hides the rest of the
/// input, and a stray `</think>` is simply dropped.
fn strip_thinking(content: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut visible = String::new();
    let mut hidden = false;
    // `rest` is always the not-yet-scanned suffix, so it stays on a char boundary.
    let mut rest = content;

    while !rest.is_empty() {
        if let Some(tail) = rest.strip_prefix(OPEN_TAG) {
            hidden = true;
            rest = tail;
            continue;
        }
        if let Some(tail) = rest.strip_prefix(CLOSE_TAG) {
            hidden = false;
            rest = tail;
            continue;
        }

        let mut chars = rest.chars();
        match chars.next() {
            Some(ch) => {
                if !hidden {
                    visible.push(ch);
                }
                rest = chars.as_str();
            }
            None => break,
        }
    }

    visible.trim().to_string()
}
54
55fn parse_anthropic_usage(usage: &serde_json::Value) -> TokenUsage {
57 TokenUsage {
58 input_tokens: usage
59 .get("input_tokens")
60 .and_then(|v| v.as_u64())
61 .unwrap_or(0),
62 output_tokens: usage
63 .get("output_tokens")
64 .and_then(|v| v.as_u64())
65 .unwrap_or(0),
66 cache_creation_input_tokens: usage
67 .get("cache_creation_input_tokens")
68 .and_then(|v| v.as_u64()),
69 cache_read_input_tokens: usage
70 .get("cache_read_input_tokens")
71 .and_then(|v| v.as_u64()),
72 }
73}
74
/// Bookkeeping for automatic conversation compaction.
#[derive(Debug, Clone, Default)]
pub struct AutoCompactTracking {
    /// Set to `true` right after a successful compaction and cleared again
    /// before the next API request is built.
    pub compacted: bool,
    /// Incremented once per iteration of the request loop.
    pub turn_counter: u32,
    /// Number of compaction attempts that failed in a row; reset to zero on
    /// success. Compaction is skipped once this reaches 3.
    pub consecutive_failures: u32,
}
85
86#[allow(dead_code)]
87pub struct QueryEngine {
88 config: QueryEngineConfig,
89 messages: Vec<crate::types::Message>,
90 turn_count: u32,
91 total_usage: TokenUsage,
92 total_cost: f64,
93 http_client: reqwest::Client,
94 tool_executors: Mutex<HashMap<String, Arc<ToolExecutor>>>,
96 hook_registry: Arc<Mutex<Option<HookRegistry>>>,
98 auto_compact_tracking: AutoCompactTracking,
100 permission_denials: Vec<PermissionDenial>,
102 last_stop_reason: Option<String>,
104}
105
106type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
107type ToolExecutor = dyn Fn(serde_json::Value, &ToolContext) -> BoxFuture<Result<ToolResult, AgentError>>
108 + Send
109 + Sync;
110
111#[derive(Debug, Clone, Default)]
113pub struct PermissionDenial {
114 pub tool_name: String,
115 pub tool_use_id: String,
116 pub tool_input: serde_json::Value,
117}
118
119pub struct QueryEngineConfig {
120 pub cwd: String,
121 pub model: String,
122 pub api_key: Option<String>,
123 pub base_url: Option<String>,
124 pub tools: Vec<ToolDefinition>,
125 pub system_prompt: Option<String>,
126 pub max_turns: u32,
127 pub max_budget_usd: Option<f64>,
128 pub max_tokens: u32,
129 pub can_use_tool: Option<fn(ToolDefinition, serde_json::Value) -> bool>,
132 pub on_event: Option<std::sync::Arc<dyn Fn(AgentEvent) + Send + Sync>>,
134}
135
136impl Default for QueryEngineConfig {
137 fn default() -> Self {
138 Self {
139 cwd: String::new(),
140 model: String::new(),
141 api_key: None,
142 base_url: None,
143 tools: vec![],
144 system_prompt: None,
145 max_turns: 10,
146 max_budget_usd: None,
147 max_tokens: 16384,
148 can_use_tool: None,
149 on_event: None,
150 }
151 }
152}
153
154impl QueryEngine {
155 pub fn new(config: QueryEngineConfig) -> Self {
156 Self {
157 config,
158 messages: vec![],
159 turn_count: 0,
160 total_usage: TokenUsage {
161 input_tokens: 0,
162 output_tokens: 0,
163 cache_creation_input_tokens: None,
164 cache_read_input_tokens: None,
165 },
166 total_cost: 0.0,
167 http_client: reqwest::Client::new(),
168 tool_executors: Mutex::new(HashMap::new()),
169 hook_registry: Arc::new(Mutex::new(None)),
170 auto_compact_tracking: AutoCompactTracking::default(),
171 permission_denials: Vec::new(),
172 last_stop_reason: None,
173 }
174 }
175
176 pub fn register_tool<F>(&mut self, name: String, executor: F)
178 where
179 F: Fn(serde_json::Value, &ToolContext) -> BoxFuture<Result<ToolResult, AgentError>>
180 + Send
181 + Sync
182 + 'static,
183 {
184 self.tool_executors
185 .lock()
186 .unwrap()
187 .insert(name, Arc::new(executor));
188 }
189
190 pub fn set_messages(&mut self, messages: Vec<crate::types::Message>) {
192 self.messages = messages;
193 }
194
195 pub async fn execute_tool(
197 &mut self,
198 name: &str,
199 input: serde_json::Value,
200 ) -> Result<ToolResult, AgentError> {
201 let context = ToolContext {
202 cwd: self.config.cwd.clone(),
203 abort_signal: None,
204 };
205
206 let executor = {
208 let executors = self.tool_executors.lock().unwrap();
209 executors.get(name).cloned()
210 };
211
212 if let Some(executor) = executor {
213 let tool_use_id = uuid::Uuid::new_v4().to_string();
216 if let Some(can_use_tool_fn) = &self.config.can_use_tool {
217 if let Some(tool_def) = self.config.tools.iter().find(|t| &t.name == name) {
219 if !can_use_tool_fn(tool_def.clone(), input.clone()) {
221 self.permission_denials.push(PermissionDenial {
223 tool_name: name.to_string(),
224 tool_use_id: tool_use_id.clone(),
225 tool_input: input.clone(),
226 });
227 return Err(AgentError::Tool(format!(
228 "Tool '{}' permission denied",
229 name
230 )));
231 }
232 }
233 }
234
235 let tool_use_id = uuid::Uuid::new_v4().to_string();
238
239 if let Some(ref cb) = self.config.on_event {
241 cb(AgentEvent::ToolStart {
242 tool_name: name.to_string(),
243 tool_call_id: tool_use_id.clone(),
244 input: input.clone(),
245 });
246 }
247
248 self.run_pre_tool_use_hooks(name, &input, &tool_use_id)
249 .await?;
250
251 let result = executor(input, &context).await;
253
254 if let Some(ref cb) = self.config.on_event {
256 match &result {
257 Ok(tool_result) => {
258 cb(AgentEvent::ToolComplete {
259 tool_name: name.to_string(),
260 tool_call_id: tool_use_id.clone(),
261 result: tool_result.clone(),
262 });
263 }
264 Err(e) => {
265 cb(AgentEvent::ToolError {
266 tool_name: name.to_string(),
267 tool_call_id: tool_use_id.clone(),
268 error: e.to_string(),
269 });
270 }
271 }
272 }
273
274 match &result {
276 Ok(tool_result) => {
277 self.run_post_tool_use_hooks(name, tool_result, &tool_use_id)
278 .await;
279 }
280 Err(e) => {
281 self.run_post_tool_use_failure_hooks(name, e, &tool_use_id)
282 .await;
283 }
284 }
285
286 result
287 } else {
288 Err(AgentError::Tool(format!("Tool '{}' not found", name)))
289 }
290 }
291
292 pub fn set_hook_registry(&self, registry: HookRegistry) {
294 let mut guard = self.hook_registry.lock().unwrap();
295 *guard = Some(registry);
296 }
297
298 pub fn set_event_callback<F>(&mut self, callback: F)
300 where
301 F: Fn(AgentEvent) + Send + Sync + 'static,
302 {
303 self.config.on_event = Some(std::sync::Arc::new(callback));
304 }
305
306 async fn run_pre_tool_use_hooks(
308 &self,
309 tool_name: &str,
310 tool_input: &serde_json::Value,
311 tool_use_id: &str,
312 ) -> Result<(), AgentError> {
313 let has_hooks = {
315 let guard = self.hook_registry.lock().unwrap();
316 guard
317 .as_ref()
318 .map(|r| r.has_hooks("PreToolUse"))
319 .unwrap_or(false)
320 };
321
322 if !has_hooks {
323 return Ok(());
324 }
325
326 let input = HookInput {
328 event: "PreToolUse".to_string(),
329 tool_name: Some(tool_name.to_string()),
330 tool_input: Some(tool_input.clone()),
331 tool_output: None,
332 tool_use_id: Some(tool_use_id.to_string()),
333 session_id: None,
334 cwd: Some(self.config.cwd.clone()),
335 error: None,
336 };
337
338 let registry = {
340 let guard = self.hook_registry.lock().unwrap();
341 guard.as_ref().cloned()
342 };
343
344 if let Some(registry) = registry {
345 let results = registry.execute("PreToolUse", input).await;
346
347 for output in results {
349 if let Some(block) = output.block {
350 if block {
351 return Err(AgentError::Tool(format!(
352 "Tool '{}' blocked by PreToolUse hook",
353 tool_name
354 )));
355 }
356 }
357 }
358 }
359 Ok(())
360 }
361
362 async fn run_post_tool_use_hooks(
364 &self,
365 tool_name: &str,
366 tool_output: &ToolResult,
367 tool_use_id: &str,
368 ) {
369 let has_hooks = {
370 let guard = self.hook_registry.lock().unwrap();
371 guard
372 .as_ref()
373 .map(|r| r.has_hooks("PostToolUse"))
374 .unwrap_or(false)
375 };
376
377 if !has_hooks {
378 return;
379 }
380
381 let input = HookInput {
382 event: "PostToolUse".to_string(),
383 tool_name: Some(tool_name.to_string()),
384 tool_input: None,
385 tool_output: Some(serde_json::json!({
386 "result_type": tool_output.result_type,
387 "content": tool_output.content,
388 "is_error": tool_output.is_error,
389 })),
390 tool_use_id: Some(tool_use_id.to_string()),
391 session_id: None,
392 cwd: Some(self.config.cwd.clone()),
393 error: None,
394 };
395
396 let registry = {
397 let guard = self.hook_registry.lock().unwrap();
398 guard.as_ref().cloned()
399 };
400
401 if let Some(registry) = registry {
402 let _ = registry.execute("PostToolUse", input).await;
403 }
404 }
405
406 async fn run_post_tool_use_failure_hooks(
408 &self,
409 tool_name: &str,
410 error: &AgentError,
411 tool_use_id: &str,
412 ) {
413 let has_hooks = {
414 let guard = self.hook_registry.lock().unwrap();
415 guard
416 .as_ref()
417 .map(|r| r.has_hooks("PostToolUseFailure"))
418 .unwrap_or(false)
419 };
420
421 if !has_hooks {
422 return;
423 }
424
425 let input = HookInput {
426 event: "PostToolUseFailure".to_string(),
427 tool_name: Some(tool_name.to_string()),
428 tool_input: None,
429 tool_output: None,
430 tool_use_id: Some(tool_use_id.to_string()),
431 session_id: None,
432 cwd: Some(self.config.cwd.clone()),
433 error: Some(error.to_string()),
434 };
435
436 let registry = {
437 let guard = self.hook_registry.lock().unwrap();
438 guard.as_ref().cloned()
439 };
440
441 if let Some(registry) = registry {
442 let _ = registry.execute("PostToolUseFailure", input).await;
443 }
444 }
445
446 pub fn get_turn_count(&self) -> u32 {
447 self.turn_count
448 }
449
450 pub fn get_usage(&self) -> TokenUsage {
452 self.total_usage.clone()
453 }
454
455 pub fn get_messages(&self) -> Vec<crate::types::Message> {
456 self.messages.clone()
457 }
458
459 async fn do_auto_compact(&mut self) -> Result<bool, AgentError> {
463 let token_count = compact::estimate_token_count(&self.messages, self.config.max_tokens);
464 let threshold = get_auto_compact_threshold(&self.config.model);
465
466 if token_count <= threshold {
468 return Ok(false);
469 }
470
471 let summary = self.generate_summary().await?;
473
474 let keep_last = 4;
476 let messages_to_keep: Vec<Message> = if self.messages.len() > keep_last {
477 self.messages[self.messages.len() - keep_last..].to_vec()
478 } else {
479 self.messages.clone()
480 };
481
482 let boundary_msg = Message {
484 role: MessageRole::System,
485 content: format!("[Previous conversation summarized]\n\n{}", summary),
486 ..Default::default()
487 };
488
489 let mut new_messages = vec![boundary_msg];
491 new_messages.extend(messages_to_keep);
492
493 let _new_token_count = compact::estimate_token_count(&new_messages, self.config.max_tokens);
494
495 self.messages = new_messages;
496 Ok(true)
497 }
498
499 async fn generate_summary(&self) -> Result<String, AgentError> {
502 let compact_prompt = get_compact_prompt();
503
504 let mut summary_messages = vec![Message {
506 role: MessageRole::User,
507 content: compact_prompt,
508 ..Default::default()
509 }];
510
511 for msg in &self.messages {
513 if let MessageRole::System = msg.role {
514 if msg.content.contains("compacted") || msg.content.contains("summarized") {
516 continue;
517 }
518 }
519 summary_messages.push(msg.clone());
520 }
521
522 let max_summary_tokens = 2048u32; let (truncated_summary_messages, estimated_tokens) = compact::truncate_messages_for_summary(
526 &summary_messages,
527 &self.config.model,
528 max_summary_tokens,
529 );
530
531 if estimated_tokens > 150000 {
533 return Err(AgentError::Api(format!(
534 "Cannot generate summary: estimated {} tokens exceeds safe limit",
535 estimated_tokens
536 )));
537 }
538
539 let summary_messages = truncated_summary_messages;
541
542 let api_key = self
544 .config
545 .api_key
546 .as_ref()
547 .ok_or_else(|| AgentError::Api("API key not provided".to_string()))?;
548
549 let base_url = self
550 .config
551 .base_url
552 .as_ref()
553 .map(|s| s.as_str())
554 .unwrap_or("https://api.anthropic.com");
555
556 let model = &self.config.model;
558
559 let api_summary_messages: Vec<serde_json::Value> = summary_messages
562 .iter()
563 .map(|msg| {
564 let role_str = match msg.role {
565 MessageRole::User => "user",
566 MessageRole::Assistant => "assistant",
567 MessageRole::Tool => "user", MessageRole::System => "system",
569 };
570 let mut msg_json = serde_json::json!({
571 "role": role_str,
572 "content": msg.content
573 });
574 if let Some(tool_call_id) = &msg.tool_call_id {
575 msg_json["tool_call_id"] = serde_json::json!(tool_call_id);
576 }
577 msg_json
578 })
579 .collect();
580
581 let request_body = serde_json::json!({
584 "model": model,
585 "max_tokens": 2048,
586 "messages": api_summary_messages,
587 });
588
589 let client = reqwest::Client::new();
590 let url = format!("{}/v1/chat/completions", base_url);
591 let response = client
592 .post(&url)
593 .header("Authorization", format!("Bearer {}", api_key))
594 .header("Content-Type", "application/json")
595 .json(&request_body)
596 .send()
597 .await
598 .map_err(|e| AgentError::Api(format!("Failed to send summary request: {}", e)))?;
599
600 let response_text = response
601 .text()
602 .await
603 .map_err(|e| AgentError::Api(format!("Failed to read summary response: {}", e)))?;
604
605 let response_json: serde_json::Value =
607 serde_json::from_str(&response_text).map_err(|e| {
608 AgentError::Api(format!(
609 "Failed to parse summary response: {} - {}",
610 e, response_text
611 ))
612 })?;
613
614 if let Some(error) = response_json.get("error") {
616 return Err(AgentError::Api(format!("Summary API error: {}", error)));
617 }
618
619 let summary = extract_text_from_response(&response_json);
621
622 if summary.is_empty() {
623 return Err(AgentError::Api("Summary response was empty".to_string()));
624 }
625
626 let parsed_summary = parse_compact_summary(&summary);
629
630 Ok(parsed_summary)
631 }
632
633 pub async fn submit_message(&mut self, prompt: &str) -> Result<(String, crate::types::ExitReason), AgentError> {
634 self.messages.push(crate::types::Message {
636 role: crate::types::MessageRole::User,
637 content: prompt.to_string(),
638 ..Default::default()
639 });
640
641 let threshold = get_auto_compact_threshold(&self.config.model);
647 let token_count = compact::estimate_token_count(&self.messages, self.config.max_tokens);
648
649 if self.auto_compact_tracking.consecutive_failures < 3 && token_count > threshold {
650 match self.do_auto_compact().await {
652 Ok(true) => {
653 self.auto_compact_tracking.compacted = true;
655 self.auto_compact_tracking.consecutive_failures = 0;
656 }
657 Ok(false) => {
658 }
660 Err(e) => {
661 self.auto_compact_tracking.consecutive_failures += 1;
663 eprintln!("Auto-compact failed: {}", e);
664 }
665 }
666 }
667
668 if let Some(ref cb) = self.config.on_event {
671 cb(AgentEvent::Thinking { turn: 1 });
672 }
673
674 let mut max_tool_turns = 10;
676 while max_tool_turns > 0 {
677 max_tool_turns -= 1;
678
679 self.auto_compact_tracking.turn_counter += 1;
681
682 let token_count = compact::estimate_token_count(&self.messages, self.config.max_tokens);
684 let threshold = get_auto_compact_threshold(&self.config.model);
685 let _effective_window = get_effective_context_window_size(&self.config.model);
686
687 if self.auto_compact_tracking.consecutive_failures < 3 && token_count > threshold {
691 match self.do_auto_compact().await {
693 Ok(true) => {
694 self.auto_compact_tracking.compacted = true;
696 self.auto_compact_tracking.consecutive_failures = 0;
697 continue;
699 }
700 Ok(false) => {
701 }
703 Err(e) => {
704 self.auto_compact_tracking.consecutive_failures += 1;
706 eprintln!("Auto-compact failed: {}", e);
707 }
708 }
709 }
710
711 self.auto_compact_tracking.compacted = false;
713
714 let api_messages = self.build_api_messages()?;
716
717 let api_key = self
719 .config
720 .api_key
721 .as_ref()
722 .ok_or_else(|| AgentError::Api("API key not provided".to_string()))?;
723
724 let base_url = self
725 .config
726 .base_url
727 .as_ref()
728 .map(|s| s.as_str())
729 .unwrap_or("https://api.anthropic.com");
730
731 let model = &self.config.model;
732
733 let mut request_body = serde_json::json!({
737 "model": model,
738 "max_tokens": self.config.max_tokens,
739 "messages": api_messages,
740 "stream": true
741 });
742
743 if let Some(system_prompt) = &self.config.system_prompt {
745 request_body["system"] = serde_json::json!(system_prompt);
746 }
747
748 if !self.config.tools.is_empty() {
750 let use_anthropic_format = base_url.contains("anthropic.com");
755
756 let tools: Vec<serde_json::Value> = self
757 .config
758 .tools
759 .iter()
760 .map(|t| {
761 if use_anthropic_format {
762 serde_json::json!({
765 "type": "function",
766 "name": t.name,
767 "description": t.description,
768 "input_schema": t.input_schema
769 })
770 } else {
771 serde_json::json!({
774 "type": "function",
775 "function": {
776 "name": t.name,
777 "description": t.description,
778 "parameters": t.input_schema
779 }
780 })
781 }
782 })
783 .collect();
784 request_body["tools"] = serde_json::json!(tools);
785 }
786
787 let url = if base_url.contains("anthropic.com") {
790 format!("{}/v1/messages", base_url)
791 } else {
792 format!("{}/v1/chat/completions", base_url)
793 };
794
795 let streaming_result = match make_anthropic_streaming_request(
798 &self.http_client,
799 &url,
800 api_key,
801 request_body.clone(),
802 self.config.on_event.clone(),
803 )
804 .await
805 {
806 Ok(result) => result,
807 Err(e) => {
808 eprintln!("Streaming failed, falling back to non-streaming: {}", e);
811
812 make_nonstreaming_request(
814 &self.http_client,
815 &url,
816 api_key,
817 request_body,
818 self.config.on_event.clone(),
819 )
820 .await?
821 }
822 };
823
824 if streaming_result.tool_calls.is_empty() {
826 let response_text = streaming_result.content;
828
829 let cleaned = strip_thinking(&response_text);
832 let final_text = if cleaned.is_empty() && !response_text.is_empty() {
833 response_text.clone()
834 } else {
835 cleaned
836 };
837
838 self.total_usage.input_tokens += streaming_result.usage.input_tokens;
840 self.total_usage.output_tokens += streaming_result.usage.output_tokens;
841
842 self.messages.push(crate::types::Message {
844 role: crate::types::MessageRole::Assistant,
845 content: response_text.clone(),
846 ..Default::default()
847 });
848
849 let next_turn_count = self.turn_count + 1;
851 if self.config.max_turns > 0 && next_turn_count > self.config.max_turns {
852 if let Some(ref cb) = self.config.on_event {
854 cb(AgentEvent::MaxTurnsReached {
855 max_turns: self.config.max_turns,
856 turn_count: next_turn_count,
857 });
858 }
859 return Ok((final_text, crate::types::ExitReason::MaxTurns {
861 max_turns: self.config.max_turns,
862 turn_count: next_turn_count,
863 }));
864 }
865
866 self.turn_count = next_turn_count;
868
869 if let Some(ref cb) = self.config.on_event {
871 cb(AgentEvent::Thinking {
872 turn: self.turn_count,
873 });
874 }
875
876 return Ok((final_text, crate::types::ExitReason::Completed));
878 }
879
880 let tool_calls = streaming_result.tool_calls;
882
883 let mut tool_call_structs: Vec<crate::types::ToolCall> = Vec::new();
885 for tc in &tool_calls {
886 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("").to_string();
887 let id = tc.get("id").and_then(|i| i.as_str()).unwrap_or("").to_string();
888 let arguments = tc.get("arguments").cloned().unwrap_or(serde_json::Value::Null);
889 tool_call_structs.push(crate::types::ToolCall {
890 id,
891 name,
892 arguments,
893 });
894 }
895
896 let tool_context = crate::types::ToolContext {
899 cwd: self.config.cwd.clone(),
900 abort_signal: None,
901 };
902
903 let tool_executors = Arc::new(self.tool_executors.lock().unwrap().clone());
906 let tools = self.config.tools.clone();
907 let can_use_tool = self.config.can_use_tool;
908 let cwd = self.config.cwd.clone();
909 let on_event = self.config.on_event.clone();
910
911 let executor = move |name: String, args: serde_json::Value, tool_call_id: String| {
912 let tool_executors = tool_executors.clone();
913 let tools = tools.clone();
914 let can_use_tool = can_use_tool;
915 let cwd = cwd.clone();
916 let on_event = on_event.clone();
917 async move {
918 if let Some(ref cb) = on_event {
920 cb(AgentEvent::ToolStart {
921 tool_name: name.clone(),
922 tool_call_id: tool_call_id.clone(),
923 input: args.clone(),
924 });
925 }
926
927 let context = crate::types::ToolContext {
929 cwd,
930 abort_signal: None,
931 };
932
933 let executor = tool_executors.get(&name).cloned();
934
935 if let Some(executor) = executor {
936 if let Some(can_use_fn) = can_use_tool {
938 if let Some(tool_def) = tools.iter().find(|t| &t.name == &name) {
939 if !can_use_fn(tool_def.clone(), args.clone()) {
940 return Err(crate::error::AgentError::Tool(format!(
941 "Tool '{}' permission denied",
942 name
943 )));
944 }
945 }
946 }
947
948 let result = executor(args, &context).await;
949
950 if let Some(ref cb) = on_event {
952 match &result {
953 Ok(tool_result) => {
954 cb(AgentEvent::ToolComplete {
955 tool_name: name.clone(),
956 tool_call_id: tool_call_id.clone(),
957 result: tool_result.clone(),
958 });
959 }
960 Err(e) => {
961 cb(AgentEvent::ToolError {
962 tool_name: name.clone(),
963 tool_call_id: tool_call_id.clone(),
964 error: e.to_string(),
965 });
966 }
967 }
968 }
969
970 result
971 } else {
972 let err = crate::error::AgentError::Tool(format!(
973 "Tool '{}' not found",
974 name
975 ));
976 if let Some(ref cb) = on_event {
978 cb(AgentEvent::ToolError {
979 tool_name: name.clone(),
980 tool_call_id: tool_call_id.clone(),
981 error: err.to_string(),
982 });
983 }
984 Err(err)
985 }
986 }
987 };
988
989 let assistant_msg = crate::types::Message {
992 role: crate::types::MessageRole::Assistant,
993 content: format!("Calling tool(s): {:?}", tool_calls.iter().map(|tc| tc.get("name").and_then(|n| n.as_str()).unwrap_or("")).collect::<Vec<_>>()),
994 tool_calls: Some(tool_call_structs.clone()),
995 ..Default::default()
996 };
997 self.messages.push(assistant_msg);
998
999 let updates = orchestration::run_tools(
1000 tool_call_structs,
1001 self.config.tools.clone(),
1002 tool_context,
1003 executor,
1004 )
1005 .await;
1006
1007 for update in updates {
1009 if let Some(message) = update.message {
1010 let truncated_content =
1013 truncate_tool_result_content(&message.content, "");
1014 let mut msg = message;
1015 msg.content = truncated_content;
1016 self.messages.push(msg);
1017 }
1018 }
1019
1020 let next_turn_count = self.turn_count + 1;
1022 if self.config.max_turns > 0 && next_turn_count > self.config.max_turns {
1023 if let Some(ref cb) = self.config.on_event {
1025 cb(AgentEvent::MaxTurnsReached {
1026 max_turns: self.config.max_turns,
1027 turn_count: next_turn_count,
1028 });
1029 }
1030 let final_text = self
1032 .messages
1033 .iter()
1034 .filter(|m| m.role == crate::types::MessageRole::Assistant)
1035 .last()
1036 .map(|m| m.content.clone())
1037 .unwrap_or_else(|| "Max turns reached".to_string());
1038 let final_text = strip_thinking(&final_text);
1039 if let Some(ref cb) = self.config.on_event {
1040 cb(AgentEvent::Done {
1041 result: crate::types::QueryResult {
1042 text: final_text.clone(),
1043 usage: self.total_usage.clone(),
1044 num_turns: self.turn_count,
1045 duration_ms: 0,
1046 exit_reason: crate::types::ExitReason::default(),
1047 },
1048 });
1049 }
1050 return Ok((final_text, crate::types::ExitReason::default()));
1051 }
1052
1053 self.turn_count = next_turn_count;
1056
1057 if let Some(ref cb) = self.config.on_event {
1059 cb(AgentEvent::Thinking {
1060 turn: self.turn_count,
1061 });
1062 }
1063
1064 }
1066
1067 let final_text = self
1069 .messages
1070 .iter()
1071 .filter(|m| m.role == crate::types::MessageRole::Assistant)
1072 .last()
1073 .map(|m| m.content.clone())
1074 .unwrap_or_else(|| "Max tool execution turns reached".to_string());
1075
1076 let final_text = strip_thinking(&final_text);
1078
1079 if let Some(ref cb) = self.config.on_event {
1081 cb(AgentEvent::Done {
1082 result: crate::types::QueryResult {
1083 text: final_text.clone(),
1084 usage: self.total_usage.clone(),
1085 num_turns: self.turn_count,
1086 duration_ms: 0, exit_reason: crate::types::ExitReason::Completed,
1088 },
1089 });
1090 }
1091
1092 Ok((final_text, crate::types::ExitReason::Completed))
1093 }
1094
1095 fn build_api_messages(&self) -> Result<Vec<serde_json::Value>, AgentError> {
1096 let base_url = self.config.base_url.as_deref().unwrap_or("https://api.anthropic.com");
1098 let is_anthropic = base_url.contains("anthropic.com");
1099
1100 let mut api_messages: Vec<serde_json::Value> = Vec::new();
1101
1102 for msg in &self.messages {
1105 match msg.role {
1106 crate::types::MessageRole::User => {
1107 api_messages.push(serde_json::json!({
1109 "role": "user",
1110 "content": msg.content
1111 }));
1112 }
1113 crate::types::MessageRole::Assistant => {
1114 if let Some(tool_calls) = &msg.tool_calls {
1116 if is_anthropic {
1117 let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1119
1120 if !msg.content.is_empty()
1122 && msg.content
1123 != format!(
1124 "Calling tool: {} with args: ",
1125 tool_calls.first().map(|t| t.name.as_str()).unwrap_or("")
1126 )
1127 {
1128 content_blocks.push(serde_json::json!({
1129 "type": "text",
1130 "text": msg.content
1131 }));
1132 }
1133
1134 for tc in tool_calls {
1136 content_blocks.push(serde_json::json!({
1137 "type": "tool_use",
1138 "id": tc.id,
1139 "name": tc.name,
1140 "input": tc.arguments
1141 }));
1142 }
1143
1144 api_messages.push(serde_json::json!({
1145 "role": "assistant",
1146 "content": content_blocks
1147 }));
1148 } else {
1149 let mut openai_tool_calls: Vec<serde_json::Value> = Vec::new();
1152 for tc in tool_calls {
1153 openai_tool_calls.push(serde_json::json!({
1154 "id": tc.id,
1155 "type": "function",
1156 "function": {
1157 "name": tc.name,
1158 "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default()
1159 }
1160 }));
1161 }
1162
1163 api_messages.push(serde_json::json!({
1164 "role": "assistant",
1165 "content": msg.content,
1166 "tool_calls": openai_tool_calls
1167 }));
1168 }
1169 } else {
1170 api_messages.push(serde_json::json!({
1172 "role": "assistant",
1173 "content": msg.content
1174 }));
1175 }
1176 }
1177 crate::types::MessageRole::Tool => {
1178 let tool_use_id = msg.tool_call_id.clone().unwrap_or_default();
1180
1181 let content = if msg.is_error == Some(true) {
1183 format!("<tool_use_error>{}</tool_use_error>", msg.content)
1184 } else {
1185 msg.content.clone()
1186 };
1187
1188 api_messages.push(serde_json::json!({
1193 "role": "tool",
1194 "content": content,
1195 "tool_call_id": tool_use_id
1196 }));
1197 }
1198 crate::types::MessageRole::System => {
1199 api_messages.push(serde_json::json!({
1201 "role": "user",
1202 "content": msg.content
1203 }));
1204 }
1205 }
1206 }
1207 Ok(api_messages)
1208 }
1209}
1210
1211fn calculate_compaction_messages(
1215 messages: &[crate::types::Message],
1216 target_tokens: u32,
1217) -> Vec<crate::types::Message> {
1218 if messages.len() <= 4 {
1219 return messages.to_vec();
1221 }
1222
1223 let avg_tokens_per_msg = 500;
1225 let target_message_count = (target_tokens / avg_tokens_per_msg).max(10) as usize;
1226
1227 let keep_first = 2;
1230 let keep_last = target_message_count.saturating_sub(keep_first);
1231
1232 if messages.len() <= keep_first + keep_last {
1233 return messages.to_vec();
1234 }
1235
1236 let first_part = &messages[..keep_first];
1237 let last_part = &messages[messages.len() - keep_last..];
1238
1239 let mut result = Vec::with_capacity(keep_first + keep_last);
1240 result.extend(first_part.iter().cloned());
1241 result.extend(last_part.iter().cloned());
1242 result
1243}
1244
1245fn extract_text_from_response(response: &serde_json::Value) -> String {
1247 if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) {
1249 if let Some(first_choice) = choices.first() {
1250 if let Some(content) = first_choice.get("message").and_then(|m| m.get("content")) {
1251 if let Some(text) = content.as_str() {
1252 return text.to_string();
1253 }
1254 }
1255 }
1256 }
1257
1258 if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
1260 for block in content {
1261 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
1262 return text.to_string();
1263 }
1264 }
1265 }
1266
1267 String::new()
1268}
1269
/// Extracts the usable summary text from a raw model reply.
///
/// * If the reply contains `<summary>…</summary>`, the trimmed tag body is
///   returned. Trimmed text following the closing tag is appended after a
///   blank line unless it is empty or starts with `<`. If nothing usable
///   results, the whole trimmed reply is returned.
/// * Otherwise, a `<analysis>…</analysis>` section (if any) is removed and
///   the trimmed remainder is returned.
///
/// Closing tags are only searched for after their opening tag, and offsets
/// come from the tag lengths, so malformed or multibyte input never causes
/// an out-of-range slice.
fn parse_compact_summary(raw_summary: &str) -> String {
    const SUMMARY_OPEN: &str = "<summary>";
    const SUMMARY_CLOSE: &str = "</summary>";
    const ANALYSIS_OPEN: &str = "<analysis>";
    const ANALYSIS_CLOSE: &str = "</analysis>";

    if let Some(open_at) = raw_summary.find(SUMMARY_OPEN) {
        let body_start = open_at + SUMMARY_OPEN.len();
        if let Some(body_len) = raw_summary[body_start..].find(SUMMARY_CLOSE) {
            let body_end = body_start + body_len;
            let mut summary = raw_summary[body_start..body_end].trim().to_string();

            let remaining = raw_summary[body_end + SUMMARY_CLOSE.len()..].trim();
            if !remaining.is_empty() && !remaining.starts_with('<') {
                // Separate body and trailing text, but avoid a leading gap
                // when the tag body itself was empty.
                if !summary.is_empty() {
                    summary.push_str("\n\n");
                }
                summary.push_str(remaining);
            }

            return if summary.is_empty() {
                raw_summary.trim().to_string()
            } else {
                summary
            };
        }
    }

    if let Some(open_at) = raw_summary.find(ANALYSIS_OPEN) {
        if let Some(close_rel) = raw_summary[open_at..].find(ANALYSIS_CLOSE) {
            let after_close = open_at + close_rel + ANALYSIS_CLOSE.len();
            let without_analysis = format!(
                "{}{}",
                &raw_summary[..open_at],
                raw_summary[after_close..].trim()
            );
            return without_analysis.trim().to_string();
        }
    }

    raw_summary.trim().to_string()
}
1310
1311fn extract_tool_calls(response: &serde_json::Value) -> Vec<serde_json::Value> {
1312 if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) {
1314 if let Some(first_choice) = choices.first() {
1315 if let Some(message) = first_choice.get("message") {
1316 if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
1317 if !tool_calls.is_empty() {
1318 return tool_calls
1319 .iter()
1320 .map(|tc| {
1321 let func = tc.get("function");
1322 let name = func
1323 .and_then(|f| f.get("name"))
1324 .cloned()
1325 .unwrap_or(serde_json::Value::Null);
1326 let args = func.and_then(|f| f.get("arguments"));
1328 let arguments = if let Some(args_val) = args {
1329 if let Some(arg_str) = args_val.as_str() {
1330 serde_json::from_str(arg_str).unwrap_or(args_val.clone())
1332 } else {
1333 args_val.clone()
1334 }
1335 } else {
1336 serde_json::Value::Null
1337 };
1338 let id = tc.get("id").cloned();
1340 let mut result = serde_json::json!({
1341 "name": name,
1342 "arguments": arguments,
1343 });
1344 if let Some(id_val) = id {
1345 result["id"] = id_val;
1346 }
1347 result
1348 })
1349 .collect();
1350 }
1351 }
1352 }
1353 }
1354 }
1355
1356 vec![]
1357}
1358fn extract_response_text(response: &serde_json::Value) -> String {
1361 if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) {
1363 if let Some(first_choice) = choices.first() {
1364 if let Some(message) = first_choice.get("message") {
1365 if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
1366 return content.to_string();
1367 }
1368 }
1369 }
1370 }
1371
1372 if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
1374 for block in content {
1375 if let Some(block_type) = block.get("type").and_then(|t| t.as_str()) {
1376 match block_type {
1377 "text" => {
1378 if let Some(t) = block.get("text").and_then(|t| t.as_str()) {
1379 return t.to_string();
1380 }
1381 }
1382 _ => {}
1383 }
1384 }
1385 }
1386 }
1387
1388 String::new()
1389}
1390
1391fn extract_usage(response: &serde_json::Value) -> TokenUsage {
1392 if let Some(usage) = response.get("usage") {
1394 return TokenUsage {
1395 input_tokens: usage
1396 .get("prompt_tokens")
1397 .and_then(|v| v.as_u64())
1398 .unwrap_or(0)
1399 + usage
1400 .get("completion_tokens")
1401 .and_then(|v| v.as_u64())
1402 .unwrap_or(0),
1403 output_tokens: usage
1404 .get("completion_tokens")
1405 .and_then(|v| v.as_u64())
1406 .unwrap_or(0),
1407 cache_creation_input_tokens: None,
1408 cache_read_input_tokens: None,
1409 };
1410 }
1411
1412 let usage = response.get("usage");
1414 TokenUsage {
1415 input_tokens: usage
1416 .and_then(|u| u.get("input_tokens"))
1417 .and_then(|v| v.as_u64())
1418 .unwrap_or(0),
1419 output_tokens: usage
1420 .and_then(|u| u.get("output_tokens"))
1421 .and_then(|v| v.as_u64())
1422 .unwrap_or(0),
1423 cache_creation_input_tokens: usage
1424 .and_then(|u| u.get("cache_creation_input_tokens"))
1425 .and_then(|v| v.as_u64()),
1426 cache_read_input_tokens: usage
1427 .and_then(|u| u.get("cache_read_input_tokens"))
1428 .and_then(|v| v.as_u64()),
1429 }
1430}
1431
1432#[derive(Debug, Default)]
1434struct StreamingResult {
1435 content: String,
1436 tool_calls: Vec<serde_json::Value>,
1437 usage: TokenUsage,
1438}
1439
1440async fn make_nonstreaming_request(
1443 client: &reqwest::Client,
1444 url: &str,
1445 api_key: &str,
1446 mut request_body: serde_json::Value,
1447 on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,
1448) -> Result<StreamingResult, AgentError> {
1449 request_body["stream"] = serde_json::json!(false);
1451
1452 let is_anthropic = url.contains("anthropic.com");
1454
1455 let response = if is_anthropic {
1456 client
1458 .post(url)
1459 .header("x-api-key", api_key)
1460 .header("anthropic-version", "2023-06-01")
1461 .header("Content-Type", "application/json")
1462 .json(&request_body)
1463 .send()
1464 .await
1465 .map_err(|e| AgentError::Api(format!("Non-streaming request failed: {}", e)))?
1466 } else {
1467 client
1469 .post(url)
1470 .header("Authorization", format!("Bearer {}", api_key))
1471 .header("Content-Type", "application/json")
1472 .json(&request_body)
1473 .send()
1474 .await
1475 .map_err(|e| AgentError::Api(format!("Non-streaming request failed: {}", e)))?
1476 };
1477
1478 let status = response.status();
1479 if !status.is_success() {
1480 let error_text = response.text().await.unwrap_or_default();
1481 return Err(AgentError::Api(format!(
1482 "Non-streaming API error {}: {}",
1483 status, error_text
1484 )));
1485 }
1486
1487 if let Some(ref cb) = on_event {
1489 cb(AgentEvent::MessageStart {
1490 message_id: uuid::Uuid::new_v4().to_string(),
1491 });
1492 }
1493
1494 let response_text = response
1496 .text()
1497 .await
1498 .map_err(|e| AgentError::Api(format!("Failed to read non-streaming response: {}", e)))?;
1499
1500 let response_json: serde_json::Value =
1502 serde_json::from_str(&response_text).map_err(|e| {
1503 AgentError::Api(format!(
1504 "Failed to parse non-streaming response: {} - {}",
1505 e, response_text
1506 ))
1507 })?;
1508
1509 if let Some(error) = response_json.get("error") {
1511 return Err(AgentError::Api(format!("API error: {}", error)));
1512 }
1513
1514 let mut result = StreamingResult::default();
1515
1516 if let Some(content) = response_json.get("content").and_then(|c| c.as_array()) {
1518 for block in content {
1519 let block_type = block.get("type").and_then(|t| t.as_str());
1520 match block_type {
1521 Some("text") => {
1522 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
1523 result.content.push_str(text);
1524 }
1525 }
1526 Some("tool_use") => {
1527 let tool_id = block.get("id").and_then(|i| i.as_str()).unwrap_or("");
1528 let tool_name = block.get("name").and_then(|n| n.as_str()).unwrap_or("");
1529 let tool_input = block.get("input").cloned().unwrap_or(serde_json::Value::Null);
1530
1531 result.tool_calls.push(serde_json::json!({
1532 "id": tool_id,
1533 "name": tool_name,
1534 "arguments": tool_input,
1535 }));
1536 }
1537 _ => {}
1538 }
1539 }
1540 if let Some(usage) = response_json.get("usage") {
1542 result.usage = parse_anthropic_usage(usage);
1543 }
1544 }
1545 else if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
1547 if let Some(first_choice) = choices.first() {
1548 if let Some(message) = first_choice.get("message") {
1549 if let Some(content) = message.get("content").and_then(|c| c.as_str()) {
1551 result.content = content.to_string();
1552 }
1553 if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
1555 for tc in tool_calls {
1556 let id = tc.get("id").and_then(|i| i.as_str()).unwrap_or("");
1557 let func = tc.get("function");
1558 let name = func
1559 .and_then(|f| f.get("name"))
1560 .and_then(|n| n.as_str())
1561 .unwrap_or("");
1562 let args = func.and_then(|f| f.get("arguments"));
1563 let args_val = if let Some(args_str) = args.and_then(|a| a.as_str()) {
1564 serde_json::from_str(args_str).unwrap_or(serde_json::Value::Null)
1565 } else {
1566 args.cloned().unwrap_or(serde_json::Value::Null)
1567 };
1568 result.tool_calls.push(serde_json::json!({
1569 "id": id,
1570 "name": name,
1571 "arguments": args_val,
1572 }));
1573 }
1574 }
1575 }
1576 }
1577 if let Some(usage) = response_json.get("usage") {
1579 result.usage = TokenUsage {
1580 input_tokens: usage
1581 .get("prompt_tokens")
1582 .and_then(|v| v.as_u64())
1583 .unwrap_or(0),
1584 output_tokens: usage
1585 .get("completion_tokens")
1586 .and_then(|v| v.as_u64())
1587 .unwrap_or(0),
1588 cache_creation_input_tokens: None,
1589 cache_read_input_tokens: None,
1590 };
1591 }
1592 }
1593
1594 if let Some(ref cb) = on_event {
1596 cb(AgentEvent::ContentBlockStart {
1597 index: 0,
1598 block_type: "text".to_string(),
1599 });
1600 if !result.content.is_empty() {
1601 cb(AgentEvent::ContentBlockDelta {
1602 index: 0,
1603 delta: ContentDelta::Text {
1604 text: result.content.clone(),
1605 },
1606 });
1607 }
1608 cb(AgentEvent::ContentBlockStop { index: 0 });
1609 cb(AgentEvent::MessageStop);
1610 }
1611
1612 Ok(result)
1613}
1614
1615async fn make_anthropic_streaming_request(
1618 client: &reqwest::Client,
1619 url: &str,
1620 api_key: &str,
1621 request_body: serde_json::Value,
1622 on_event: Option<Arc<dyn Fn(AgentEvent) + Send + Sync>>,
1623) -> Result<StreamingResult, AgentError> {
1624 use futures_util::stream::StreamExt;
1625
1626 let is_anthropic = url.contains("anthropic.com");
1628
1629 let response = if is_anthropic {
1630 client
1632 .post(url)
1633 .header("x-api-key", api_key)
1634 .header("anthropic-version", "2023-06-01")
1635 .header("Content-Type", "application/json")
1636 .header("Accept", "text/event-stream")
1637 .json(&request_body)
1638 .send()
1639 .await
1640 .map_err(|e| AgentError::Api(format!("Streaming request failed: {}", e)))?
1641 } else {
1642 client
1644 .post(url)
1645 .header("Authorization", format!("Bearer {}", api_key))
1646 .header("Content-Type", "application/json")
1647 .header("Accept", "text/event-stream")
1648 .json(&request_body)
1649 .send()
1650 .await
1651 .map_err(|e| AgentError::Api(format!("Streaming request failed: {}", e)))?
1652 };
1653
1654 let status = response.status();
1655 if !status.is_success() {
1656 let error_text = response.text().await.unwrap_or_default();
1657 return Err(AgentError::Api(format!(
1658 "Streaming API error {}: {}",
1659 status, error_text
1660 )));
1661 }
1662
1663 if let Some(ref cb) = on_event {
1665 cb(AgentEvent::MessageStart {
1666 message_id: uuid::Uuid::new_v4().to_string(),
1667 });
1668 }
1669
1670 let body = response.bytes_stream();
1672 let mut stream: futures_util::stream::BoxStream<'_, Result<bytes::Bytes, reqwest::Error>> =
1673 Box::pin(body);
1674
1675 let mut result = StreamingResult::default();
1676 let mut current_tool_use: Option<(String, String, String)> = None; let mut content_index: u32 = 0;
1678 let mut tool_use_index: u32 = 0;
1679 let mut in_tool_use = false;
1680 let mut text_block_started = false;
1681
1682 while let Some(chunk_result) = stream.next().await {
1684 let chunk =
1685 chunk_result.map_err(|e| AgentError::Api(format!("Stream read error: {}", e)))?;
1686
1687 if let Ok(text) = String::from_utf8(chunk.to_vec()) {
1689 if !text.starts_with("data: ") {
1691 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
1694 if json.get("content").is_some() && json.get("choices").is_none() {
1697 if let Some(content_array) = json.get("content").and_then(|c| c.as_array())
1699 {
1700 for block in content_array {
1701 let block_type = block.get("type").and_then(|t| t.as_str());
1702 match block_type {
1703 Some("text") => {
1704 if let Some(text) =
1705 block.get("text").and_then(|t| t.as_str())
1706 {
1707 result.content.push_str(text);
1708 }
1709 }
1710 Some("tool_use") => {
1711 let tool_id =
1712 block.get("id").and_then(|i| i.as_str()).unwrap_or("");
1713 let tool_name = block
1714 .get("name")
1715 .and_then(|n| n.as_str())
1716 .unwrap_or("");
1717 let tool_input = block
1718 .get("input")
1719 .cloned()
1720 .unwrap_or(serde_json::Value::Null);
1721 result.tool_calls.push(serde_json::json!({
1722 "id": tool_id,
1723 "name": tool_name,
1724 "arguments": tool_input,
1725 }));
1726 }
1727 _ => {}
1728 }
1729 }
1730 if let Some(usage) = json.get("usage") {
1732 result.usage = parse_anthropic_usage(usage);
1733 }
1734 if let Some(ref cb) = on_event {
1736 cb(AgentEvent::MessageStart {
1737 message_id: json
1738 .get("id")
1739 .and_then(|i| i.as_str())
1740 .unwrap_or("")
1741 .to_string(),
1742 });
1743 cb(AgentEvent::ContentBlockStart {
1744 index: 0,
1745 block_type: "text".to_string(),
1746 });
1747 if !result.content.is_empty() {
1748 cb(AgentEvent::ContentBlockDelta {
1749 index: 0,
1750 delta: ContentDelta::Text {
1751 text: result.content.clone(),
1752 },
1753 });
1754 }
1755 cb(AgentEvent::ContentBlockStop { index: 0 });
1756 cb(AgentEvent::MessageStop);
1757 }
1758 return Ok(result);
1759 }
1760 if let Some(content) = json.get("content").and_then(|c| c.as_str()) {
1762 result.content.push_str(content);
1763 }
1764 if let Some(stop_reason) = json.get("stop_reason") {
1766 if !stop_reason.is_null() {
1767 if let Some(ref cb) = on_event {
1768 cb(AgentEvent::ContentBlockStart {
1769 index: 0,
1770 block_type: "text".to_string(),
1771 });
1772 if !result.content.is_empty() {
1773 cb(AgentEvent::ContentBlockDelta {
1774 index: 0,
1775 delta: ContentDelta::Text {
1776 text: result.content.clone(),
1777 },
1778 });
1779 }
1780 cb(AgentEvent::ContentBlockStop { index: 0 });
1781 cb(AgentEvent::MessageStop);
1782 }
1783 return Ok(result);
1784 }
1785 }
1786 continue;
1787 }
1788
1789 if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
1791 if let Some(first) = choices.first() {
1792 if let Some(delta) = first.get("delta") {
1793 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
1795 {
1796 result.content.push_str(content);
1797 }
1798 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
1800 for tc in tool_calls {
1801 let id = tc.get("id").and_then(|i| i.as_str()).unwrap_or("");
1802 let func = tc.get("function");
1803 let name = func.and_then(|f| f.get("name")).and_then(|n| n.as_str()).unwrap_or("");
1804 let args = func.and_then(|f| f.get("arguments"));
1805 let args_val = if let Some(args_str) = args.and_then(|a| a.as_str()) {
1806 serde_json::from_str(args_str).unwrap_or(serde_json::Value::Null)
1807 } else {
1808 args.cloned().unwrap_or(serde_json::Value::Null)
1809 };
1810 result.tool_calls.push(serde_json::json!({
1812 "id": id,
1813 "name": name,
1814 "arguments": args_val,
1815 }));
1816 }
1817 }
1818 }
1819 if let Some(finish_reason) =
1822 first.get("finish_reason").and_then(|f| f.as_str())
1823 {
1824 if !finish_reason.is_empty()
1825 && finish_reason != "null"
1826 && (!result.content.is_empty() || !result.tool_calls.is_empty())
1827 {
1828 if let Some(ref cb) = on_event {
1829 cb(AgentEvent::ContentBlockStop { index: 0 });
1830 cb(AgentEvent::MessageStop);
1831 }
1832 return Ok(result);
1833 }
1834 }
1835 }
1836 continue;
1837 }
1838
1839 if json.get("choices").is_some() {
1841 if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
1844 if let Some(first) = choices.first() {
1845 if let Some(msg) = first.get("message") {
1846 if let Some(content) =
1848 msg.get("content").and_then(|c| c.as_str())
1849 {
1850 result.content = content.to_string();
1851 }
1852 if let Some(tool_calls) =
1854 msg.get("tool_calls").and_then(|t| t.as_array())
1855 {
1856 for tc in tool_calls {
1857 let id =
1858 tc.get("id").and_then(|i| i.as_str()).unwrap_or("");
1859 let func = tc.get("function");
1860 let name = func
1861 .and_then(|f| f.get("name"))
1862 .and_then(|n| n.as_str())
1863 .unwrap_or("");
1864 let args = func.and_then(|f| f.get("arguments"));
1865 let args_val = if let Some(args_str) =
1866 args.and_then(|a| a.as_str())
1867 {
1868 serde_json::from_str(args_str)
1869 .unwrap_or(serde_json::Value::Null)
1870 } else {
1871 args.cloned().unwrap_or(serde_json::Value::Null)
1872 };
1873 result.tool_calls.push(serde_json::json!({
1874 "id": id,
1875 "name": name,
1876 "arguments": args_val,
1877 }));
1878 }
1879 }
1880 }
1881 }
1882 }
1883 if let Some(usage) = json.get("usage") {
1885 result.usage = TokenUsage {
1886 input_tokens: usage
1887 .get("prompt_tokens")
1888 .and_then(|v| v.as_u64())
1889 .unwrap_or(0),
1890 output_tokens: usage
1891 .get("completion_tokens")
1892 .and_then(|v| v.as_u64())
1893 .unwrap_or(0),
1894 cache_creation_input_tokens: None,
1895 cache_read_input_tokens: None,
1896 };
1897 }
1898 if let Some(ref cb) = on_event {
1900 cb(AgentEvent::ContentBlockStart {
1901 index: 0,
1902 block_type: "text".to_string(),
1903 });
1904 if !result.content.is_empty() {
1905 cb(AgentEvent::ContentBlockDelta {
1906 index: 0,
1907 delta: ContentDelta::Text {
1908 text: result.content.clone(),
1909 },
1910 });
1911 }
1912 cb(AgentEvent::ContentBlockStop { index: 0 });
1913 cb(AgentEvent::MessageStop);
1914 }
1915 return Ok(result);
1916 }
1917 }
1918 continue;
1919 }
1920
1921 for line in text.lines() {
1923 if line.starts_with("data: ") {
1924 let data = &line[6..];
1925
1926 if data == "[DONE]" {
1928 continue;
1929 }
1930
1931 if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
1933 if let Some(event_type) = json.get("type").and_then(|t| t.as_str()) {
1935 match event_type {
1936 "message_start" => {
1937 if let Some(usage) = json.get("usage") {
1939 result.usage = parse_anthropic_usage(usage);
1940 }
1941 }
1942 "content_block_start" => {
1943 let index =
1944 json.get("index").and_then(|i| i.as_u64()).unwrap_or(0)
1945 as u32;
1946 let block_type = json
1947 .get("content_block")
1948 .and_then(|b| b.get("type"))
1949 .and_then(|t| t.as_str())
1950 .unwrap_or("text")
1951 .to_string();
1952
1953 if block_type == "tool_use" {
1954 tool_use_index = index;
1955 in_tool_use = true;
1956 let tool_name = json
1957 .get("content_block")
1958 .and_then(|b| b.get("name"))
1959 .and_then(|n| n.as_str())
1960 .unwrap_or("")
1961 .to_string();
1962 let tool_id = json
1963 .get("content_block")
1964 .and_then(|b| b.get("id"))
1965 .and_then(|i| i.as_str())
1966 .unwrap_or("")
1967 .to_string();
1968 current_tool_use =
1969 Some((tool_id, tool_name, String::new()));
1970 } else {
1971 content_index = index;
1972 text_block_started = true;
1973 }
1974
1975 if let Some(ref cb) = on_event {
1976 cb(AgentEvent::ContentBlockStart { index, block_type });
1977 }
1978 }
1979 "content_block_delta" => {
1980 let index =
1981 json.get("index").and_then(|i| i.as_u64()).unwrap_or(0)
1982 as u32;
1983 if let Some(delta) = json.get("delta") {
1984 let delta_type = delta.get("type").and_then(|t| t.as_str());
1985
1986 match delta_type {
1987 Some("text_delta") => {
1988 if let Some(text) =
1989 delta.get("text").and_then(|t| t.as_str())
1990 {
1991 result.content.push_str(text);
1992
1993 if let Some(ref cb) = on_event {
1994 cb(AgentEvent::ContentBlockDelta {
1995 index,
1996 delta: ContentDelta::Text {
1997 text: text.to_string(),
1998 },
1999 });
2000 }
2001 }
2002 }
2003 Some("input_json_delta") => {
2004 let partial_json = delta
2005 .get("partial_json")
2006 .and_then(|p| p.as_str())
2007 .unwrap_or("");
2008
2009 if let Some(ref mut current) = current_tool_use {
2010 current.2.push_str(partial_json);
2011 }
2012
2013 if let Some(ref cb) = on_event {
2014 let tool_name = current_tool_use
2015 .as_ref()
2016 .map(|(_, n, _)| n.clone())
2017 .unwrap_or_default();
2018 let tool_id = current_tool_use
2019 .as_ref()
2020 .map(|(i, _, _)| i.clone())
2021 .unwrap_or_default();
2022 cb(AgentEvent::ContentBlockDelta {
2023 index,
2024 delta: ContentDelta::ToolUse {
2025 id: tool_id,
2026 name: tool_name,
2027 input: serde_json::json!({ "partial": partial_json }),
2028 is_complete: false,
2029 },
2030 });
2031 }
2032 }
2033 _ => {}
2034 }
2035 }
2036 }
2037 "content_block_stop" => {
2038 let index =
2039 json.get("index").and_then(|i| i.as_u64()).unwrap_or(0)
2040 as u32;
2041
2042 if in_tool_use && index == tool_use_index {
2044 if let Some((id, name, args_str)) = current_tool_use.take()
2045 {
2046 let args: serde_json::Value =
2047 serde_json::from_str(&args_str)
2048 .unwrap_or(serde_json::Value::Null);
2049
2050 result.tool_calls.push(serde_json::json!({
2051 "id": id,
2052 "name": name,
2053 "arguments": args,
2054 }));
2055 }
2056 in_tool_use = false;
2057 }
2058
2059 if let Some(ref cb) = on_event {
2060 cb(AgentEvent::ContentBlockStop { index });
2061 }
2062 }
2063 "message_delta" => {
2064 if let Some(usage) = json.get("usage") {
2066 result.usage = parse_anthropic_usage(usage);
2067 }
2068 }
2069 "message_stop" => {
2070 }
2072 _ => {}
2073 }
2074 }
2075
2076 if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
2079 if let Some(first) = choices.first() {
2080 if let Some(delta) = first.get("delta") {
2081 if let Some(content) =
2083 delta.get("content").and_then(|c| c.as_str())
2084 {
2085 if !content.is_empty() {
2086 result.content.push_str(content);
2087
2088 if let Some(ref cb) = on_event {
2089 cb(AgentEvent::ContentBlockDelta {
2090 index: 0,
2091 delta: ContentDelta::Text {
2092 text: content.to_string(),
2093 },
2094 });
2095 }
2096 }
2097 }
2098 if let Some(tool_calls) =
2100 delta.get("tool_calls").and_then(|t| t.as_array())
2101 {
2102 for tc in tool_calls {
2103 let id = tc
2104 .get("id")
2105 .and_then(|i| i.as_str())
2106 .unwrap_or("");
2107 let func = tc.get("function");
2108 let name = func
2109 .and_then(|f| f.get("name"))
2110 .and_then(|n| n.as_str())
2111 .unwrap_or("");
2112 let args = func.and_then(|f| f.get("arguments"));
2113 let args_val =
2114 if let Some(args_str) = args.and_then(|a| a.as_str())
2115 {
2116 serde_json::from_str(args_str)
2117 .unwrap_or(serde_json::Value::Null)
2118 } else {
2119 args.cloned()
2120 .unwrap_or(serde_json::Value::Null)
2121 };
2122 result.tool_calls.push(serde_json::json!({
2123 "id": id,
2124 "name": name,
2125 "arguments": args_val,
2126 }));
2127 }
2128 }
2129 }
2130 if let Some(finish_reason) =
2132 first.get("finish_reason").and_then(|f| f.as_str())
2133 {
2134 if !finish_reason.is_empty() && finish_reason != "null" {
2135 if let Some(ref cb) = on_event {
2136 cb(AgentEvent::ContentBlockStop { index: 0 });
2137 cb(AgentEvent::MessageStop);
2138 }
2139 return Ok(result);
2140 }
2141 }
2142 }
2143 continue;
2145 }
2146
2147 if json.get("content").is_some() || json.get("id").is_some() {
2149 if let Some(content_array) =
2151 json.get("content").and_then(|c| c.as_array())
2152 {
2153 for block in content_array {
2154 let block_type = block.get("type").and_then(|t| t.as_str());
2155 match block_type {
2156 Some("text") => {
2157 if let Some(text) =
2158 block.get("text").and_then(|t| t.as_str())
2159 {
2160 result.content.push_str(text);
2161 }
2162 }
2163 Some("tool_use") => {
2164 let tool_id = block
2165 .get("id")
2166 .and_then(|i| i.as_str())
2167 .unwrap_or("");
2168 let tool_name = block
2169 .get("name")
2170 .and_then(|n| n.as_str())
2171 .unwrap_or("");
2172 let tool_input = block
2173 .get("input")
2174 .cloned()
2175 .unwrap_or(serde_json::Value::Null);
2176
2177 result.tool_calls.push(serde_json::json!({
2178 "id": tool_id,
2179 "name": tool_name,
2180 "arguments": tool_input,
2181 }));
2182 }
2183 _ => {}
2184 }
2185 }
2186 }
2187
2188 if let Some(usage) = json.get("usage") {
2190 result.usage = parse_anthropic_usage(usage);
2191 }
2192
2193 if let Some(ref cb) = on_event {
2195 cb(AgentEvent::ContentBlockStart {
2196 index: 0,
2197 block_type: "text".to_string(),
2198 });
2199 if !result.content.is_empty() {
2200 cb(AgentEvent::ContentBlockDelta {
2201 index: 0,
2202 delta: ContentDelta::Text {
2203 text: result.content.clone(),
2204 },
2205 });
2206 }
2207 cb(AgentEvent::ContentBlockStop { index: 0 });
2208 cb(AgentEvent::MessageStop);
2209 }
2210 return Ok(result);
2211 }
2212 }
2213 }
2214 }
2215 }
2216 }
2217
2218 if let Some(ref cb) = on_event {
2220 cb(AgentEvent::MessageStop);
2221 }
2222
2223 Ok(result)
2224}
2225
2226#[cfg(test)]
2227#[allow(unused_imports)]
2228mod tests {
2229 use super::*;
2230
2231 #[tokio::test]
2232 async fn test_engine_creation() {
2233 let engine = QueryEngine::new(QueryEngineConfig {
2234 cwd: "/tmp".to_string(),
2235 model: "claude-sonnet-4-6".to_string(),
2236 api_key: None,
2237 base_url: None,
2238 tools: vec![],
2239 system_prompt: None,
2240 max_turns: 10,
2241 max_budget_usd: None,
2242 max_tokens: 16384,
2243 can_use_tool: None,
2244 on_event: None,
2245 });
2246 assert_eq!(engine.get_turn_count(), 0);
2247 }
2248
2249 #[tokio::test]
2250 async fn test_engine_submit_message() {
2251 let mut engine = QueryEngine::new(QueryEngineConfig {
2252 cwd: "/tmp".to_string(),
2253 model: "claude-sonnet-4-6".to_string(),
2254 api_key: None,
2255 base_url: None,
2256 tools: vec![],
2257 system_prompt: None,
2258 max_turns: 10,
2259 max_budget_usd: None,
2260 max_tokens: 16384,
2261 can_use_tool: None,
2262 on_event: None,
2263 });
2264
2265 let result = engine.submit_message("Hello").await;
2266 assert!(result.is_err());
2268 }
2269
/// `strip_thinking` removes `<think>…</think>` sections from ASCII input.
#[test]
fn test_strip_thinking() {
    // (input, expected output)
    let cases = [
        (
            "<think>I should list the files here.</think>Here are the files: file1.txt, file2.txt",
            "Here are the files: file1.txt, file2.txt",
        ),
        // No thinking section at all.
        ("Hello world", "Hello world"),
        // Nothing but a thinking section.
        ("<think>Thinking...</think>", ""),
        // Several thinking sections interleaved with text.
        (
            "<think>First think</think>Hello<think>Second think</think>World",
            "HelloWorld",
        ),
    ];

    for &(input, expected) in cases.iter() {
        assert_eq!(strip_thinking(input), expected);
    }
}
2293
/// `strip_thinking` handles multi-byte characters inside and outside of
/// `<think>…</think>` sections.
#[test]
fn test_strip_thinking_utf8() {
    // (input, expected output)
    let cases = [
        ("<think>思考</think>Hello → World", "Hello → World"),
        ("<think>中文</think>你好世界", "你好世界"),
        ("<think>thinking emoji 🎭</think>Hello 👋 World", "Hello 👋 World"),
        ("<think>The → symbol is here</think>Result: 你好 🎉", "Result: 你好 🎉"),
        ("<think>thinking开始啦</think>继续内容", "继续内容"),
        ("开始内容<think>thinking结束啦</think>", "开始内容"),
        ("<think>第一步思考→思考第二步</think>执行→完成", "执行→完成"),
    ];

    for &(input, expected) in cases.iter() {
        assert_eq!(strip_thinking(input), expected);
    }
}
2331
/// Tool calls can be read out of an OpenAI-style non-streaming response
/// (`choices[0].message.tool_calls`) and normalised to id/name/arguments.
#[test]
fn test_fallback_tool_call_extraction() {
    use serde_json::json;

    let response = json!({
        "choices": [
            {
                "message": {
                    "content": null,
                    "tool_calls": [
                        {
                            "id": "call_123",
                            "type": "function",
                            "function": {
                                "name": "Bash",
                                "arguments": "{\"command\": \"ls -la\"}"
                            }
                        }
                    ]
                },
                "finish_reason": "tool_calls"
            }
        ],
        "usage": {
            "prompt_tokens": 100,
            "completion_tokens": 50
        }
    });

    // Indexing a missing key or position yields JSON null, so the chain
    // simply produces no calls when the shape does not match.
    let tool_calls: Vec<serde_json::Value> = response["choices"][0]["message"]["tool_calls"]
        .as_array()
        .map(|calls| {
            calls
                .iter()
                .map(|call| {
                    let function = &call["function"];
                    let arguments = match function.get("arguments") {
                        Some(serde_json::Value::String(raw)) => {
                            serde_json::from_str(raw).unwrap_or(serde_json::Value::Null)
                        }
                        Some(other) => other.clone(),
                        None => serde_json::Value::Null,
                    };
                    serde_json::json!({
                        "id": call["id"].as_str().unwrap_or(""),
                        "name": function["name"].as_str().unwrap_or(""),
                        "arguments": arguments,
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0]["name"], "Bash");
    assert_eq!(tool_calls[0]["id"], "call_123");
}
2397
/// Tool calls can be located inside an OpenAI-style streaming chunk
/// (`choices[0].delta.tool_calls`).
#[test]
fn test_streaming_tool_call_extraction() {
    use serde_json::json;

    let chunk = json!({
        "choices": [
            {
                "delta": {
                    "tool_calls": [
                        {
                            "id": "call_456",
                            "type": "function",
                            "function": {
                                "name": "Read",
                                "arguments": "{\"file_path\": \"/tmp/test\"}"
                            }
                        }
                    ]
                },
                "finish_reason": "tool_calls"
            }
        ]
    });

    // Indexing yields JSON null for anything missing, so `as_array` is `None`
    // unless the whole path exists.
    let calls = chunk["choices"][0]["delta"]["tool_calls"].as_array();
    assert!(calls.is_some());

    let first_call = calls.unwrap().first().unwrap();
    assert_eq!(first_call["id"].as_str(), Some("call_456"));
    assert_eq!(first_call["function"]["name"].as_str(), Some("Read"));
}
2443
/// Every base tool serialises into the function-tool JSON shape with a
/// non-empty name.
#[test]
fn test_tool_definition_serialization() {
    use crate::tools::get_all_base_tools;
    use serde_json::json;

    let tools = get_all_base_tools();
    assert!(!tools.is_empty());

    for tool in tools.iter() {
        let serialized = json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.input_schema
            }
        });

        assert!(serialized.get("type").is_some());
        assert!(serialized.get("function").is_some());

        let function = serialized.get("function").unwrap();
        for key in ["name", "description", "parameters"].iter() {
            assert!(function.get(*key).is_some());
        }

        let tool_name = function.get("name").unwrap().as_str().unwrap();
        assert!(!tool_name.is_empty());
    }
}
2480
/// `ToolCall` values keep the id and name they were constructed with.
#[test]
fn test_tool_call_parsing() {
    use crate::types::{Message, MessageRole, ToolCall};

    let tool_calls = vec![
        ToolCall {
            id: String::from("call_abc123"),
            name: String::from("Bash"),
            arguments: serde_json::json!({"command": "ls -la"}),
        },
        ToolCall {
            id: String::from("call_def456"),
            name: String::from("Read"),
            arguments: serde_json::json!({"path": "/tmp/test.txt"}),
        },
    ];

    assert_eq!(tool_calls.len(), 2);

    // (expected id, expected name), in construction order.
    let expected = [("call_abc123", "Bash"), ("call_def456", "Read")];
    for (call, (id, name)) in tool_calls.iter().zip(expected.iter()) {
        assert_eq!(call.id, *id);
        assert_eq!(call.name, *name);
    }
}
2506
/// A tool-result `Message` keeps its role, tool call id and error flag.
#[test]
fn test_tool_result_message_format() {
    use crate::types::{Message, MessageRole};

    let tool_message = Message {
        role: MessageRole::Tool,
        content: String::from("file content here"),
        tool_call_id: Some(String::from("call_abc123")),
        is_error: Some(false),
        ..Default::default()
    };

    assert_eq!(tool_message.role, MessageRole::Tool);
    assert_eq!(
        tool_message.tool_call_id,
        Some(String::from("call_abc123"))
    );
    assert_eq!(tool_message.is_error, Some(false));
}
2524
/// A `ToolContext` keeps the working directory it was built with.
#[test]
fn test_tool_execution_context() {
    use crate::types::ToolContext;

    let context = ToolContext {
        cwd: String::from("/tmp/test"),
        abort_signal: None,
    };

    assert_eq!(context.cwd, "/tmp/test");
}
2536
/// The base tool set contains the core shell, file and search tools.
#[test]
fn test_base_tools_available() {
    use crate::tools::get_all_base_tools;

    let tools = get_all_base_tools();
    let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();

    // Checked in this order; the failure message names the missing tool.
    let required = ["Bash", "FileRead", "FileWrite", "Glob", "Grep", "FileEdit"];
    for name in required.iter() {
        assert!(
            tool_names.contains(name),
            "{} tool must be available",
            name
        );
    }
}
2573
/// Every base tool has a name, a description, a schema type and an object
/// of properties.
#[test]
fn test_tool_schemas_have_required_fields() {
    use crate::tools::get_all_base_tools;

    for tool in get_all_base_tools().iter() {
        let tool_name = &tool.name;
        let schema = &tool.input_schema;

        assert!(!tool_name.is_empty(), "Tool {} has empty name", tool_name);
        assert!(
            !tool.description.is_empty(),
            "Tool {} has empty description",
            tool_name
        );
        assert!(
            !schema.schema_type.is_empty(),
            "Tool {} has empty schema_type",
            tool_name
        );
        assert!(
            schema.properties.is_object(),
            "Tool {} has non-object properties",
            tool_name
        );
    }
}
2605
/// The Bash, FileRead and FileWrite schemas declare their key parameters,
/// and Bash lists required parameters.
#[test]
fn test_tool_schema_has_required_parameters() {
    use crate::tools::get_all_base_tools;

    let tools = get_all_base_tools();

    // (tool name, parameters its schema must declare)
    let expectations: [(&str, &[&str]); 3] = [
        ("Bash", &["command"]),
        ("FileRead", &["path"]),
        ("FileWrite", &["path", "content"]),
    ];

    for (tool_name, params) in expectations.iter() {
        let tool = tools.iter().find(|t| t.name == *tool_name).unwrap();
        for param in params.iter() {
            assert!(
                tool.input_schema.properties.get(*param).is_some(),
                "{} tool must have '{}' parameter",
                tool_name,
                param
            );
        }
    }

    let bash_tool = tools.iter().find(|t| t.name == "Bash").unwrap();
    assert!(
        bash_tool.input_schema.required.is_some(),
        "Bash tool must have required parameters"
    );
}
2646
2647 #[tokio::test]
2648 async fn test_engine_with_tools_config() {
2649 use crate::tools::get_all_base_tools;
2650
2651 let tools = get_all_base_tools();
2652
2653 let engine = QueryEngine::new(QueryEngineConfig {
2654 cwd: "/tmp".to_string(),
2655 model: "claude-sonnet-4-6".to_string(),
2656 api_key: None,
2657 base_url: None,
2658 tools: tools.clone(),
2659 system_prompt: Some("You are a helpful assistant.".to_string()),
2660 max_turns: 10,
2661 max_budget_usd: None,
2662 max_tokens: 16384,
2663 can_use_tool: None,
2664 on_event: None,
2665 });
2666
2667 assert!(!engine.config.tools.is_empty());
2669 }
2670
2671 #[tokio::test]
2672 async fn test_engine_system_prompt_includes_tool_guidance() {
2673 let engine = QueryEngine::new(QueryEngineConfig {
2675 cwd: "/tmp".to_string(),
2676 model: "claude-sonnet-4-6".to_string(),
2677 api_key: None,
2678 base_url: None,
2679 tools: vec![],
2680 system_prompt: Some("You are an agent that helps users with software engineering tasks. Use the tools available to you to assist the user.".to_string()),
2681 max_turns: 10,
2682 max_budget_usd: None,
2683 max_tokens: 16384,
2684 can_use_tool: None,
2685 on_event: None,
2686 });
2687
2688 assert!(engine.config.system_prompt.is_some());
2690 let prompt = engine.config.system_prompt.as_ref().unwrap();
2691 assert!(prompt.contains("tools"));
2692 }
2693
/// Serializes a `ToolCall`'s `arguments` to a JSON string and parses it back,
/// checking that the `command` value survives the round trip.
#[test]
fn test_tool_call_arguments_json() {
    use crate::types::ToolCall;

    let call = ToolCall {
        id: String::from("call_test"),
        name: String::from("Bash"),
        arguments: serde_json::json!({ "command": "echo hello" }),
    };

    // Serialize the arguments to text.
    let serialized = call.arguments.to_string();
    assert!(!serialized.is_empty());

    // Parse the text back and read the command out again.
    let round_tripped: serde_json::Value = serde_json::from_str(&serialized).unwrap();
    let command = round_tripped
        .get("command")
        .and_then(|value| value.as_str());
    assert_eq!(command, Some("echo hello"));
}
2718
/// Checks that a sample system prompt string mentions "tools" and "Bash".
///
/// NOTE(review): this test only inspects a local string literal; it does not call
/// any engine API, so it cannot fail unless the literal itself is edited.
#[test]
fn test_build_api_messages_includes_tools_info() {
    let prompt =
        "You are an agent. Use the tools available to you: Bash, Read, Write, Glob, Grep, Edit.";

    for needle in ["tools", "Bash"] {
        assert!(prompt.contains(needle));
    }
}
2728
/// Checks that the base tool set contains at least ten tools and includes each of
/// the core tools by name: Bash, FileRead, FileWrite, Glob, Grep and FileEdit.
///
/// No engine is constructed and nothing is awaited here, so this is a plain
/// synchronous test rather than a `#[tokio::test]`.
#[test]
fn test_query_engine_tool_registration() {
    use crate::tools::get_all_base_tools;

    let tools = get_all_base_tools();

    assert!(tools.len() >= 10, "Should have at least 10 tools");

    // Compare names in place instead of cloning every name into a `Vec<String>` and
    // allocating a fresh `String` for each lookup; the message names the missing tool.
    for expected in ["Bash", "FileRead", "FileWrite", "Glob", "Grep", "FileEdit"] {
        assert!(
            tools.iter().any(|tool| tool.name == expected),
            "base tool set is missing '{}'",
            expected
        );
    }
}
2747
/// Wraps the Bash tool in an OpenAI-style `{"type": "function", "function": {...}}`
/// JSON object and checks that the object has the expected keys and survives a
/// serialize/parse round trip.
#[test]
fn test_openai_tool_format_compatibility() {
    use crate::tools::get_all_base_tools;
    use serde_json::json;

    let all_tools = get_all_base_tools();
    let bash = all_tools.iter().find(|tool| tool.name == "Bash").unwrap();

    let payload = json!({
        "type": "function",
        "function": {
            "name": bash.name,
            "description": bash.description,
            "parameters": bash.input_schema
        }
    });

    // Top-level discriminator plus the three function fields.
    assert_eq!(payload.get("type").unwrap(), "function");
    let function = payload.get("function").unwrap();
    for key in ["name", "description", "parameters"] {
        assert!(function.get(key).is_some());
    }

    // The payload must serialize to non-empty text...
    let encoded = payload.to_string();
    assert!(!encoded.is_empty());

    // ...and parse back with the same discriminator.
    let decoded: serde_json::Value = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.get("type").unwrap(), "function");
}
2781
2782 #[tokio::test]
2783 async fn test_engine_message_history_with_tool_calls() {
2784 use crate::types::{Message, MessageRole, ToolCall};
2785
2786 let mut engine = QueryEngine::new(QueryEngineConfig {
2787 cwd: "/tmp".to_string(),
2788 model: "claude-sonnet-4-6".to_string(),
2789 api_key: None,
2790 base_url: None,
2791 tools: vec![],
2792 system_prompt: None,
2793 max_turns: 10,
2794 max_budget_usd: None,
2795 max_tokens: 16384,
2796 can_use_tool: None,
2797 on_event: None,
2798 });
2799
2800 engine.messages.push(Message {
2802 role: MessageRole::User,
2803 content: "List files in /tmp".to_string(),
2804 ..Default::default()
2805 });
2806
2807 engine.messages.push(Message {
2809 role: MessageRole::Assistant,
2810 content: "".to_string(),
2811 tool_calls: Some(vec![ToolCall {
2812 id: "call_123".to_string(),
2813 name: "Bash".to_string(),
2814 arguments: serde_json::json!({"command": "ls /tmp"}),
2815 }]),
2816 ..Default::default()
2817 });
2818
2819 engine.messages.push(Message {
2821 role: MessageRole::Tool,
2822 content: "file1.txt\nfile2.txt".to_string(),
2823 tool_call_id: Some("call_123".to_string()),
2824 ..Default::default()
2825 });
2826
2827 assert_eq!(engine.messages.len(), 3);
2829 assert_eq!(engine.messages[1].role, MessageRole::Assistant);
2830 assert!(engine.messages[1].tool_calls.is_some());
2831 assert_eq!(engine.messages[2].role, MessageRole::Tool);
2832 assert_eq!(
2833 engine.messages[2].tool_call_id,
2834 Some("call_123".to_string())
2835 );
2836 }
2837
/// Builds a tool-result `Message` flagged as an error and checks that the error
/// flag and the error text are readable from it.
#[test]
fn test_tool_result_error_handling() {
    use crate::types::{Message, MessageRole};

    let failed_result = Message {
        role: MessageRole::Tool,
        content: String::from("Error: Permission denied"),
        tool_call_id: Some(String::from("call_err")),
        is_error: Some(true),
        ..Default::default()
    };

    assert_eq!(Some(true), failed_result.is_error);
    assert!(failed_result.content.contains("Error"));
}
2854}