1use anyhow::Result;
2use std::sync::Arc;
3
4use async_stream::try_stream;
5use futures::stream::StreamExt;
6use serde_json::{json, Value};
7use tracing::debug;
8
9use super::super::agents::Agent;
10use crate::conversation::message::{Message, MessageContent, ToolRequest};
11use crate::conversation::Conversation;
12use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
13use crate::providers::errors::ProviderError;
14use crate::providers::toolshim::{
15 augment_message_with_tool_calls, convert_tool_messages_to_text,
16 modify_system_prompt_for_tool_json, OllamaInterpreter,
17};
18
19use crate::agents::code_execution_extension::EXTENSION_NAME as CODE_EXECUTION_EXTENSION;
20#[cfg(test)]
21use crate::session::SessionType;
22use crate::session::{SessionManager, SessionStore, TokenStatsUpdate};
23use rmcp::model::Tool;
24
25fn coerce_value(s: &str, schema: &Value) -> Value {
26 let type_str = schema.get("type");
27
28 match type_str {
29 Some(Value::String(t)) => match t.as_str() {
30 "number" | "integer" => try_coerce_number(s),
31 "boolean" => try_coerce_boolean(s),
32 _ => Value::String(s.to_string()),
33 },
34 Some(Value::Array(types)) => {
35 for t in types {
37 if let Value::String(type_name) = t {
38 match type_name.as_str() {
39 "number" | "integer" if s.parse::<f64>().is_ok() => {
40 return try_coerce_number(s)
41 }
42 "boolean" if matches!(s.to_lowercase().as_str(), "true" | "false") => {
43 return try_coerce_boolean(s)
44 }
45 _ => continue,
46 }
47 }
48 }
49 Value::String(s.to_string())
50 }
51 _ => Value::String(s.to_string()),
52 }
53}
54
55fn try_coerce_number(s: &str) -> Value {
56 if let Ok(n) = s.parse::<f64>() {
57 if n.fract() == 0.0 && n >= i64::MIN as f64 && n <= i64::MAX as f64 {
58 json!(n as i64)
59 } else {
60 json!(n)
61 }
62 } else {
63 Value::String(s.to_string())
64 }
65}
66
67fn try_coerce_boolean(s: &str) -> Value {
68 match s.to_lowercase().as_str() {
69 "true" => json!(true),
70 "false" => json!(false),
71 _ => Value::String(s.to_string()),
72 }
73}
74
75fn coerce_tool_arguments(
76 arguments: Option<serde_json::Map<String, Value>>,
77 tool_schema: &Value,
78) -> Option<serde_json::Map<String, Value>> {
79 let args = arguments?;
80
81 let properties = tool_schema.get("properties").and_then(|p| p.as_object())?;
82
83 let mut coerced = serde_json::Map::new();
84
85 for (key, value) in args.iter() {
86 let coerced_value =
87 if let (Value::String(s), Some(prop_schema)) = (value, properties.get(key)) {
88 coerce_value(s, prop_schema)
89 } else {
90 value.clone()
91 };
92 coerced.insert(key.clone(), coerced_value);
93 }
94
95 Some(coerced)
96}
97
98async fn toolshim_postprocess(
99 response: Message,
100 toolshim_tools: &[Tool],
101) -> Result<Message, ProviderError> {
102 let interpreter = OllamaInterpreter::new().map_err(|e| {
103 ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
104 })?;
105
106 augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
107 .await
108 .map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e)))
109}
110
111impl Agent {
112 pub async fn prepare_tools_and_prompt(
113 &self,
114 working_dir: &std::path::Path,
115 session_prompt: Option<&str>,
116 ) -> Result<(Vec<Tool>, Vec<Tool>, String)> {
117 let mut tools = self.list_tools(None).await;
119
120 let frontend_tools = self.frontend_tools.lock().await;
122 for frontend_tool in frontend_tools.values() {
123 tools.push(frontend_tool.tool.clone());
124 }
125
126 let code_execution_active = self
127 .extension_manager
128 .is_extension_enabled(CODE_EXECUTION_EXTENSION)
129 .await;
130 if code_execution_active {
131 let code_exec_prefix = format!("{CODE_EXECUTION_EXTENSION}__");
132 tools.retain(|tool| tool.name.starts_with(&code_exec_prefix));
133 }
134
135 tools.sort_by(|a, b| a.name.cmp(&b.name));
137
138 let extensions_info = self.extension_manager.get_extensions_info().await;
140 let (extension_count, tool_count) =
141 self.extension_manager.get_extension_and_tool_counts().await;
142
143 let provider = self.provider().await?;
145 let model_config = provider.get_model_config();
146
147 let prompt_manager = self.prompt_manager.lock().await;
148 let mut system_prompt = prompt_manager
149 .builder()
150 .with_extensions(extensions_info.into_iter())
151 .with_frontend_instructions(self.frontend_instructions.lock().await.clone())
152 .with_extension_and_tool_counts(extension_count, tool_count)
153 .with_code_execution_mode(code_execution_active)
154 .with_hints(working_dir)
155 .with_enable_subagents(self.subagents_enabled().await)
156 .with_session_prompt(session_prompt.map(|s| s.to_string()))
157 .build();
158
159 let mut toolshim_tools = vec![];
161 if model_config.toolshim {
162 system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
164 toolshim_tools = tools.clone();
166 tools = vec![];
168 }
169
170 Ok((tools, toolshim_tools, system_prompt))
171 }
172
173 pub(crate) async fn stream_response_from_provider(
176 provider: Arc<dyn Provider>,
177 system_prompt: &str,
178 messages: &[Message],
179 tools: &[Tool],
180 toolshim_tools: &[Tool],
181 ) -> Result<MessageStream, ProviderError> {
182 let config = provider.get_model_config();
183
184 let messages_for_provider = if config.toolshim {
186 convert_tool_messages_to_text(messages)
187 } else {
188 Conversation::new_unvalidated(messages.to_vec())
189 };
190
191 let system_prompt = system_prompt.to_owned();
193 let tools = tools.to_owned();
194 let toolshim_tools = toolshim_tools.to_owned();
195 let provider = provider.clone();
196
197 let stream_result = if provider.supports_streaming() {
200 debug!("WAITING_LLM_STREAM_START");
201 let result = provider
202 .stream(
203 system_prompt.as_str(),
204 messages_for_provider.messages(),
205 &tools,
206 )
207 .await;
208 debug!("WAITING_LLM_STREAM_END");
209 result
210 } else {
211 debug!("WAITING_LLM_START");
212 let complete_result = provider
213 .complete(
214 system_prompt.as_str(),
215 messages_for_provider.messages(),
216 &tools,
217 )
218 .await;
219 debug!("WAITING_LLM_END");
220
221 match complete_result {
222 Ok((message, usage)) => Ok(stream_from_single_message(message, usage)),
223 Err(e) => Err(e),
224 }
225 };
226
227 let mut stream = match stream_result {
229 Ok(s) => s,
230 Err(e) => {
231 return Ok(Box::pin(try_stream! {
234 yield Err(e)?;
235 }));
236 }
237 };
238
239 Ok(Box::pin(try_stream! {
240 while let Some(Ok((mut message, usage))) = stream.next().await {
241 if let Some(usage) = usage.as_ref() {
243 crate::providers::base::set_current_model(&usage.model);
244 }
245
246 if message.is_some() && config.toolshim {
248 message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?);
249 }
250
251 yield (message, usage);
252 }
253 }))
254 }
255
256 pub(crate) async fn categorize_tool_requests(
262 &self,
263 response: &Message,
264 tools: &[Tool],
265 ) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
266 let tool_requests: Vec<ToolRequest> = response
268 .content
269 .iter()
270 .filter_map(|content| {
271 if let MessageContent::ToolRequest(req) = content {
272 let mut coerced_req = req.clone();
273
274 if let Ok(ref mut tool_call) = coerced_req.tool_call {
275 if let Some(tool) = tools.iter().find(|t| t.name == tool_call.name) {
276 let schema_value = Value::Object(tool.input_schema.as_ref().clone());
277 tool_call.arguments =
278 coerce_tool_arguments(tool_call.arguments.clone(), &schema_value);
279
280 if let Some(ref meta) = tool.meta {
281 coerced_req.tool_meta = serde_json::to_value(meta).ok();
282 }
283 }
284 }
285
286 Some(coerced_req)
287 } else {
288 None
289 }
290 })
291 .collect();
292
293 let mut filtered_content = Vec::new();
295 let mut tool_request_index = 0;
296
297 for content in &response.content {
298 match content {
299 MessageContent::ToolRequest(_) => {
300 if tool_request_index < tool_requests.len() {
301 let coerced_req = &tool_requests[tool_request_index];
302 tool_request_index += 1;
303
304 let should_include = if let Ok(tool_call) = &coerced_req.tool_call {
305 !self.is_frontend_tool(&tool_call.name).await
306 } else {
307 true
308 };
309
310 if should_include {
311 filtered_content.push(MessageContent::ToolRequest(coerced_req.clone()));
312 }
313 }
314 }
315 _ => {
316 filtered_content.push(content.clone());
317 }
318 }
319 }
320
321 let mut filtered_message =
322 Message::new(response.role.clone(), response.created, filtered_content);
323
324 if let Some(id) = response.id.clone() {
326 filtered_message = filtered_message.with_id(id);
327 }
328
329 let mut frontend_requests = Vec::new();
331 let mut other_requests = Vec::new();
332
333 for request in tool_requests {
334 if let Ok(tool_call) = &request.tool_call {
335 if self.is_frontend_tool(&tool_call.name).await {
336 frontend_requests.push(request);
337 } else {
338 other_requests.push(request);
339 }
340 } else {
341 other_requests.push(request);
343 }
344 }
345
346 (frontend_requests, other_requests, filtered_message)
347 }
348
349 pub(crate) async fn update_session_metrics(
350 session_config: &crate::agents::types::SessionConfig,
351 usage: &ProviderUsage,
352 is_compaction_usage: bool,
353 session_store: Option<&Arc<dyn SessionStore>>,
354 ) -> Result<()> {
355 let session_id = session_config.id.as_str();
356 let session = if let Some(store) = session_store {
357 store.get_session(session_id, false).await?
358 } else {
359 SessionManager::get_session(session_id, false).await?
360 };
361
362 let accumulate = |a: Option<i32>, b: Option<i32>| -> Option<i32> {
363 match (a, b) {
364 (Some(x), Some(y)) => Some(x + y),
365 _ => a.or(b),
366 }
367 };
368
369 let accumulated_total =
370 accumulate(session.accumulated_total_tokens, usage.usage.total_tokens);
371 let accumulated_input =
372 accumulate(session.accumulated_input_tokens, usage.usage.input_tokens);
373 let accumulated_output =
374 accumulate(session.accumulated_output_tokens, usage.usage.output_tokens);
375
376 let (current_total, current_input, current_output) = if is_compaction_usage {
377 let new_input = usage.usage.output_tokens;
379 (new_input, new_input, None)
380 } else {
381 (
382 usage.usage.total_tokens,
383 usage.usage.input_tokens,
384 usage.usage.output_tokens,
385 )
386 };
387
388 if let Some(store) = session_store {
389 store
390 .update_token_stats(
391 session_id,
392 TokenStatsUpdate {
393 schedule_id: session_config.schedule_id.clone(),
394 total_tokens: current_total,
395 input_tokens: current_input,
396 output_tokens: current_output,
397 accumulated_total,
398 accumulated_input,
399 accumulated_output,
400 },
401 )
402 .await?;
403 } else {
404 SessionManager::update_session(session_id)
405 .schedule_id(session_config.schedule_id.clone())
406 .total_tokens(current_total)
407 .input_tokens(current_input)
408 .output_tokens(current_output)
409 .accumulated_total_tokens(accumulated_total)
410 .accumulated_input_tokens(accumulated_input)
411 .accumulated_output_tokens(accumulated_output)
412 .apply()
413 .await?;
414 }
415
416 Ok(())
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::Message;
    use crate::model::ModelConfig;
    use crate::providers::base::{Provider, ProviderUsage, Usage};
    use crate::providers::errors::ProviderError;
    use crate::scheduler::{ScheduledJob, SchedulerError};
    use crate::scheduler_trait::SchedulerTrait;
    use crate::session::Session;
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use rmcp::object;
    use std::path::PathBuf;

    /// Provider stub: reports the stored `model_config` and answers every completion with
    /// an assistant message "ok" and default usage under the model name "mock".
    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> crate::providers::base::ProviderMetadata {
            crate::providers::base::ProviderMetadata::empty()
        }

        fn get_name(&self) -> &str {
            "mock"
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        // Ignores all inputs and returns a fixed reply.
        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message::assistant().with_text("ok"),
                ProviderUsage::new("mock".to_string(), Usage::default()),
            ))
        }
    }

    /// No-op scheduler installed on the agent under test. Every method succeeds and
    /// returns an empty or placeholder value.
    struct MockScheduler;

    #[async_trait]
    impl SchedulerTrait for MockScheduler {
        async fn add_scheduled_job(
            &self,
            _job: ScheduledJob,
            _copy_recipe: bool,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn schedule_recipe(
            &self,
            _recipe_path: PathBuf,
            _cron_schedule: Option<String>,
        ) -> anyhow::Result<(), SchedulerError> {
            Ok(())
        }
        async fn list_scheduled_jobs(&self) -> Vec<ScheduledJob> {
            vec![]
        }
        async fn remove_scheduled_job(
            &self,
            _id: &str,
            _remove_recipe: bool,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn pause_schedule(&self, _id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn unpause_schedule(&self, _id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn run_now(&self, _id: &str) -> Result<String, SchedulerError> {
            Ok("mock-session".to_string())
        }
        async fn sessions(
            &self,
            _sched_id: &str,
            _limit: usize,
        ) -> Result<Vec<(String, Session)>, SchedulerError> {
            Ok(vec![])
        }
        async fn update_schedule(
            &self,
            _sched_id: &str,
            _new_cron: String,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn kill_running_job(&self, _sched_id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn get_running_job_info(
            &self,
            _sched_id: &str,
        ) -> Result<Option<(String, DateTime<Utc>)>, SchedulerError> {
            Ok(None)
        }
    }

    /// `prepare_tools_and_prompt` should return at least one `platform__` tool plus both
    /// registered frontend tools, with the full list sorted by name.
    #[tokio::test]
    async fn prepare_tools_sorts_and_includes_frontend_and_list_tools() -> anyhow::Result<()> {
        let agent = crate::agents::Agent::new();

        agent
            .set_scheduler(std::sync::Arc::new(MockScheduler))
            .await;

        // A hidden session gives `update_provider` a session id to attach the provider to.
        let session = SessionManager::create_session(
            std::path::PathBuf::default(),
            "test-prepare-tools".to_string(),
            SessionType::Hidden,
        )
        .await?;

        // `prepare_tools_and_prompt` calls `self.provider()`, so one must be configured.
        let model_config = ModelConfig::new("test-model").unwrap();
        let provider = std::sync::Arc::new(MockProvider { model_config });
        agent.update_provider(provider, &session.id).await?;

        // Deliberately registered out of alphabetical order (z before a) so the sort check
        // below is meaningful.
        let frontend_tools = vec![
            Tool::new(
                "frontend__z_tool".to_string(),
                "Z tool".to_string(),
                object!({ "type": "object", "properties": { } }),
            ),
            Tool::new(
                "frontend__a_tool".to_string(),
                "A tool".to_string(),
                object!({ "type": "object", "properties": { } }),
            ),
        ];

        agent
            .add_extension(crate::agents::extension::ExtensionConfig::Frontend {
                name: "frontend".to_string(),
                description: "desc".to_string(),
                tools: frontend_tools,
                instructions: None,
                bundled: None,
                available_tools: vec![],
            })
            .await
            .unwrap();

        let working_dir = std::env::current_dir()?;
        let (tools, _toolshim_tools, _system_prompt) =
            agent.prepare_tools_and_prompt(&working_dir, None).await?;

        let names: Vec<String> = tools.iter().map(|t| t.name.clone().into_owned()).collect();
        assert!(names.iter().any(|n| n.starts_with("platform__")));
        assert!(names.iter().any(|n| n == "frontend__a_tool"));
        assert!(names.iter().any(|n| n == "frontend__z_tool"));

        // The returned order must equal the lexicographically sorted order.
        let mut sorted = names.clone();
        sorted.sort();
        assert_eq!(names, sorted);

        Ok(())
    }
}