Skip to main content

aster/agents/
reply_parts.rs

1use anyhow::Result;
2use std::sync::Arc;
3
4use async_stream::try_stream;
5use futures::stream::StreamExt;
6use serde_json::{json, Value};
7use tracing::debug;
8
9use super::super::agents::Agent;
10use crate::conversation::message::{Message, MessageContent, ToolRequest};
11use crate::conversation::Conversation;
12use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
13use crate::providers::errors::ProviderError;
14use crate::providers::toolshim::{
15    augment_message_with_tool_calls, convert_tool_messages_to_text,
16    modify_system_prompt_for_tool_json, OllamaInterpreter,
17};
18
19use crate::agents::code_execution_extension::EXTENSION_NAME as CODE_EXECUTION_EXTENSION;
20#[cfg(test)]
21use crate::session::SessionType;
22use crate::session::{SessionManager, SessionStore, TokenStatsUpdate};
23use rmcp::model::Tool;
24
25fn coerce_value(s: &str, schema: &Value) -> Value {
26    let type_str = schema.get("type");
27
28    match type_str {
29        Some(Value::String(t)) => match t.as_str() {
30            "number" | "integer" => try_coerce_number(s),
31            "boolean" => try_coerce_boolean(s),
32            _ => Value::String(s.to_string()),
33        },
34        Some(Value::Array(types)) => {
35            // Try each type in order
36            for t in types {
37                if let Value::String(type_name) = t {
38                    match type_name.as_str() {
39                        "number" | "integer" if s.parse::<f64>().is_ok() => {
40                            return try_coerce_number(s)
41                        }
42                        "boolean" if matches!(s.to_lowercase().as_str(), "true" | "false") => {
43                            return try_coerce_boolean(s)
44                        }
45                        _ => continue,
46                    }
47                }
48            }
49            Value::String(s.to_string())
50        }
51        _ => Value::String(s.to_string()),
52    }
53}
54
55fn try_coerce_number(s: &str) -> Value {
56    if let Ok(n) = s.parse::<f64>() {
57        if n.fract() == 0.0 && n >= i64::MIN as f64 && n <= i64::MAX as f64 {
58            json!(n as i64)
59        } else {
60            json!(n)
61        }
62    } else {
63        Value::String(s.to_string())
64    }
65}
66
67fn try_coerce_boolean(s: &str) -> Value {
68    match s.to_lowercase().as_str() {
69        "true" => json!(true),
70        "false" => json!(false),
71        _ => Value::String(s.to_string()),
72    }
73}
74
75fn coerce_tool_arguments(
76    arguments: Option<serde_json::Map<String, Value>>,
77    tool_schema: &Value,
78) -> Option<serde_json::Map<String, Value>> {
79    let args = arguments?;
80
81    let properties = tool_schema.get("properties").and_then(|p| p.as_object())?;
82
83    let mut coerced = serde_json::Map::new();
84
85    for (key, value) in args.iter() {
86        let coerced_value =
87            if let (Value::String(s), Some(prop_schema)) = (value, properties.get(key)) {
88                coerce_value(s, prop_schema)
89            } else {
90                value.clone()
91            };
92        coerced.insert(key.clone(), coerced_value);
93    }
94
95    Some(coerced)
96}
97
98async fn toolshim_postprocess(
99    response: Message,
100    toolshim_tools: &[Tool],
101) -> Result<Message, ProviderError> {
102    let interpreter = OllamaInterpreter::new().map_err(|e| {
103        ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e))
104    })?;
105
106    augment_message_with_tool_calls(&interpreter, response, toolshim_tools)
107        .await
108        .map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e)))
109}
110
impl Agent {
    /// Assemble the tool list and system prompt for one model turn.
    ///
    /// Returns `(tools, toolshim_tools, system_prompt)`:
    /// - `tools`: the sorted tool list handed to the provider. Emptied when
    ///   toolshim mode is on, since the model then emits tool calls as JSON
    ///   text instead of using native tool calling.
    /// - `toolshim_tools`: a copy of the full tool list, populated only in
    ///   toolshim mode, used later to interpret the model's text output.
    /// - `system_prompt`: the fully built system prompt.
    pub async fn prepare_tools_and_prompt(
        &self,
        working_dir: &std::path::Path,
        session_prompt: Option<&str>,
    ) -> Result<(Vec<Tool>, Vec<Tool>, String)> {
        // Get tools from extension manager
        let mut tools = self.list_tools(None).await;

        // Add frontend tools
        let frontend_tools = self.frontend_tools.lock().await;
        for frontend_tool in frontend_tools.values() {
            tools.push(frontend_tool.tool.clone());
        }

        // When the code-execution extension is active, only its own tools are
        // exposed to the model; all other tools are filtered out here.
        let code_execution_active = self
            .extension_manager
            .is_extension_enabled(CODE_EXECUTION_EXTENSION)
            .await;
        if code_execution_active {
            let code_exec_prefix = format!("{CODE_EXECUTION_EXTENSION}__");
            tools.retain(|tool| tool.name.starts_with(&code_exec_prefix));
        }

        // Stable tool ordering is important for multi session prompt caching.
        tools.sort_by(|a, b| a.name.cmp(&b.name));

        // Prepare system prompt
        let extensions_info = self.extension_manager.get_extensions_info().await;
        let (extension_count, tool_count) =
            self.extension_manager.get_extension_and_tool_counts().await;

        // Get model name from provider
        let provider = self.provider().await?;
        let model_config = provider.get_model_config();

        let prompt_manager = self.prompt_manager.lock().await;
        let mut system_prompt = prompt_manager
            .builder()
            .with_extensions(extensions_info.into_iter())
            .with_frontend_instructions(self.frontend_instructions.lock().await.clone())
            .with_extension_and_tool_counts(extension_count, tool_count)
            .with_code_execution_mode(code_execution_active)
            .with_hints(working_dir)
            .with_enable_subagents(self.subagents_enabled().await)
            .with_session_prompt(session_prompt.map(|s| s.to_string()))
            .build();

        // Handle toolshim if enabled
        let mut toolshim_tools = vec![];
        if model_config.toolshim {
            // If tool interpretation is enabled, modify the system prompt
            system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
            // Make a copy of tools before emptying
            toolshim_tools = tools.clone();
            // Empty the tools vector for provider completion
            tools = vec![];
        }

        Ok((tools, toolshim_tools, system_prompt))
    }

    /// Stream a response from the LLM provider.
    /// Handles toolshim transformations if needed
    ///
    /// For providers without streaming support, a single `complete()` call is
    /// wrapped into a one-item stream so callers only deal with one shape.
    /// Stream-creation errors are also returned *as* a stream so the agent's
    /// existing error handling sees them like any other stream error.
    pub(crate) async fn stream_response_from_provider(
        provider: Arc<dyn Provider>,
        system_prompt: &str,
        messages: &[Message],
        tools: &[Tool],
        toolshim_tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let config = provider.get_model_config();

        // Convert tool messages to text if toolshim is enabled
        let messages_for_provider = if config.toolshim {
            convert_tool_messages_to_text(messages)
        } else {
            Conversation::new_unvalidated(messages.to_vec())
        };

        // Clone owned data to move into the async stream
        let system_prompt = system_prompt.to_owned();
        let tools = tools.to_owned();
        let toolshim_tools = toolshim_tools.to_owned();
        let provider = provider.clone();

        // Capture errors during stream creation and return them as part of the stream
        // so they can be handled by the existing error handling logic in the agent
        let stream_result = if provider.supports_streaming() {
            debug!("WAITING_LLM_STREAM_START");
            let result = provider
                .stream(
                    system_prompt.as_str(),
                    messages_for_provider.messages(),
                    &tools,
                )
                .await;
            debug!("WAITING_LLM_STREAM_END");
            result
        } else {
            debug!("WAITING_LLM_START");
            let complete_result = provider
                .complete(
                    system_prompt.as_str(),
                    messages_for_provider.messages(),
                    &tools,
                )
                .await;
            debug!("WAITING_LLM_END");

            // Adapt the one-shot completion into the streaming interface.
            match complete_result {
                Ok((message, usage)) => Ok(stream_from_single_message(message, usage)),
                Err(e) => Err(e),
            }
        };

        // If there was an error creating the stream, return a stream that yields that error
        let mut stream = match stream_result {
            Ok(s) => s,
            Err(e) => {
                // Return a stream that immediately yields the error
                // This allows the error to be caught by existing error handling in agent.rs
                return Ok(Box::pin(try_stream! {
                    yield Err(e)?;
                }));
            }
        };

        Ok(Box::pin(try_stream! {
            // NOTE(review): an `Err` item from the underlying stream ends this
            // loop and is silently discarded (the `while let Some(Ok(..))`
            // pattern matches neither `Some(Err(..))` nor `None`) — confirm
            // that dropping mid-stream provider errors is intentional.
            while let Some(Ok((mut message, usage))) = stream.next().await {
                // Store the model information in the global store
                if let Some(usage) = usage.as_ref() {
                    crate::providers::base::set_current_model(&usage.model);
                }

                // Post-process / structure the response only if tool interpretation is enabled
                if message.is_some() && config.toolshim {
                    message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?);
                }

                yield (message, usage);
            }
        }))
    }

    /// Categorize tool requests from the response into different types
    /// Returns:
    /// - frontend_requests: Tool requests that should be handled by the frontend
    /// - other_requests: All other tool requests (including requests to enable extensions)
    /// - filtered_message: The original message with frontend tool requests removed
    ///
    /// Argument values are coerced against each tool's input schema (string
    /// "1" becomes integer 1 when the schema says "integer", etc.) before
    /// categorization, so both returned request lists carry coerced values.
    pub(crate) async fn categorize_tool_requests(
        &self,
        response: &Message,
        tools: &[Tool],
    ) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
        // First collect all tool requests with coercion applied
        let tool_requests: Vec<ToolRequest> = response
            .content
            .iter()
            .filter_map(|content| {
                if let MessageContent::ToolRequest(req) = content {
                    let mut coerced_req = req.clone();

                    // Only coerce when the call parsed successfully and the
                    // named tool is actually known to us.
                    if let Ok(ref mut tool_call) = coerced_req.tool_call {
                        if let Some(tool) = tools.iter().find(|t| t.name == tool_call.name) {
                            let schema_value = Value::Object(tool.input_schema.as_ref().clone());
                            tool_call.arguments =
                                coerce_tool_arguments(tool_call.arguments.clone(), &schema_value);

                            if let Some(ref meta) = tool.meta {
                                coerced_req.tool_meta = serde_json::to_value(meta).ok();
                            }
                        }
                    }

                    Some(coerced_req)
                } else {
                    None
                }
            })
            .collect();

        // Create a filtered message with frontend tool requests removed.
        // `tool_request_index` walks `tool_requests` in lockstep with the
        // ToolRequest entries of `response.content` (same order, same count),
        // so each original request is paired with its coerced counterpart.
        let mut filtered_content = Vec::new();
        let mut tool_request_index = 0;

        for content in &response.content {
            match content {
                MessageContent::ToolRequest(_) => {
                    if tool_request_index < tool_requests.len() {
                        let coerced_req = &tool_requests[tool_request_index];
                        tool_request_index += 1;

                        // Keep unparseable calls in the message so the error
                        // surfaces downstream; drop only frontend-bound calls.
                        let should_include = if let Ok(tool_call) = &coerced_req.tool_call {
                            !self.is_frontend_tool(&tool_call.name).await
                        } else {
                            true
                        };

                        if should_include {
                            filtered_content.push(MessageContent::ToolRequest(coerced_req.clone()));
                        }
                    }
                }
                _ => {
                    filtered_content.push(content.clone());
                }
            }
        }

        let mut filtered_message =
            Message::new(response.role.clone(), response.created, filtered_content);

        // Preserve the ID if it exists
        if let Some(id) = response.id.clone() {
            filtered_message = filtered_message.with_id(id);
        }

        // Categorize tool requests
        let mut frontend_requests = Vec::new();
        let mut other_requests = Vec::new();

        for request in tool_requests {
            if let Ok(tool_call) = &request.tool_call {
                if self.is_frontend_tool(&tool_call.name).await {
                    frontend_requests.push(request);
                } else {
                    other_requests.push(request);
                }
            } else {
                // If there's an error in the tool call, add it to other_requests
                other_requests.push(request);
            }
        }

        (frontend_requests, other_requests, filtered_message)
    }

    /// Persist per-turn and accumulated token counts for a session.
    ///
    /// Reads the session via `session_store` when given, otherwise via the
    /// global `SessionManager`, and writes the updated stats back the same
    /// way. When `is_compaction_usage` is true, the compaction summary's
    /// output token count is recorded as the new input/total context and the
    /// current output count is cleared.
    pub(crate) async fn update_session_metrics(
        session_config: &crate::agents::types::SessionConfig,
        usage: &ProviderUsage,
        is_compaction_usage: bool,
        session_store: Option<&Arc<dyn SessionStore>>,
    ) -> Result<()> {
        let session_id = session_config.id.as_str();
        let session = if let Some(store) = session_store {
            store.get_session(session_id, false).await?
        } else {
            SessionManager::get_session(session_id, false).await?
        };

        // Sum two optional counters; if either side is missing, keep
        // whichever is present rather than discarding data.
        let accumulate = |a: Option<i32>, b: Option<i32>| -> Option<i32> {
            match (a, b) {
                (Some(x), Some(y)) => Some(x + y),
                _ => a.or(b),
            }
        };

        let accumulated_total =
            accumulate(session.accumulated_total_tokens, usage.usage.total_tokens);
        let accumulated_input =
            accumulate(session.accumulated_input_tokens, usage.usage.input_tokens);
        let accumulated_output =
            accumulate(session.accumulated_output_tokens, usage.usage.output_tokens);

        let (current_total, current_input, current_output) = if is_compaction_usage {
            // After compaction: summary output becomes new input context
            let new_input = usage.usage.output_tokens;
            (new_input, new_input, None)
        } else {
            (
                usage.usage.total_tokens,
                usage.usage.input_tokens,
                usage.usage.output_tokens,
            )
        };

        if let Some(store) = session_store {
            store
                .update_token_stats(
                    session_id,
                    TokenStatsUpdate {
                        schedule_id: session_config.schedule_id.clone(),
                        total_tokens: current_total,
                        input_tokens: current_input,
                        output_tokens: current_output,
                        accumulated_total,
                        accumulated_input,
                        accumulated_output,
                    },
                )
                .await?;
        } else {
            SessionManager::update_session(session_id)
                .schedule_id(session_config.schedule_id.clone())
                .total_tokens(current_total)
                .input_tokens(current_input)
                .output_tokens(current_output)
                .accumulated_total_tokens(accumulated_total)
                .accumulated_input_tokens(accumulated_input)
                .accumulated_output_tokens(accumulated_output)
                .apply()
                .await?;
        }

        Ok(())
    }
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::Message;
    use crate::model::ModelConfig;
    use crate::providers::base::{Provider, ProviderUsage, Usage};
    use crate::providers::errors::ProviderError;
    use crate::scheduler::{ScheduledJob, SchedulerError};
    use crate::scheduler_trait::SchedulerTrait;
    use crate::session::Session;
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use rmcp::object;
    use std::path::PathBuf;

    /// Minimal provider stub: always answers "ok" with default usage stats.
    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> crate::providers::base::ProviderMetadata {
            crate::providers::base::ProviderMetadata::empty()
        }

        fn get_name(&self) -> &str {
            "mock"
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message::assistant().with_text("ok"),
                ProviderUsage::new("mock".to_string(), Usage::default()),
            ))
        }
    }

    /// Mock scheduler for testing
    /// Every operation is a no-op that reports success.
    struct MockScheduler;

    #[async_trait]
    impl SchedulerTrait for MockScheduler {
        async fn add_scheduled_job(
            &self,
            _job: ScheduledJob,
            _copy_recipe: bool,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn schedule_recipe(
            &self,
            _recipe_path: PathBuf,
            _cron_schedule: Option<String>,
        ) -> anyhow::Result<(), SchedulerError> {
            Ok(())
        }
        async fn list_scheduled_jobs(&self) -> Vec<ScheduledJob> {
            vec![]
        }
        async fn remove_scheduled_job(
            &self,
            _id: &str,
            _remove_recipe: bool,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn pause_schedule(&self, _id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn unpause_schedule(&self, _id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn run_now(&self, _id: &str) -> Result<String, SchedulerError> {
            Ok("mock-session".to_string())
        }
        async fn sessions(
            &self,
            _sched_id: &str,
            _limit: usize,
        ) -> Result<Vec<(String, Session)>, SchedulerError> {
            Ok(vec![])
        }
        async fn update_schedule(
            &self,
            _sched_id: &str,
            _new_cron: String,
        ) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn kill_running_job(&self, _sched_id: &str) -> Result<(), SchedulerError> {
            Ok(())
        }
        async fn get_running_job_info(
            &self,
            _sched_id: &str,
        ) -> Result<Option<(String, DateTime<Utc>)>, SchedulerError> {
            Ok(None)
        }
    }

    /// `prepare_tools_and_prompt` should merge platform and frontend tools
    /// into one list sorted ascending by name.
    #[tokio::test]
    async fn prepare_tools_sorts_and_includes_frontend_and_list_tools() -> anyhow::Result<()> {
        let agent = crate::agents::Agent::new();

        // Install the mock scheduler so the platform__ tools become available
        agent
            .set_scheduler(std::sync::Arc::new(MockScheduler))
            .await;

        let session = SessionManager::create_session(
            std::path::PathBuf::default(),
            "test-prepare-tools".to_string(),
            SessionType::Hidden,
        )
        .await?;

        let model_config = ModelConfig::new("test-model").unwrap();
        let provider = std::sync::Arc::new(MockProvider { model_config });
        agent.update_provider(provider, &session.id).await?;

        // Add unsorted frontend tools
        // (z before a, to prove sorting happens rather than insertion order)
        let frontend_tools = vec![
            Tool::new(
                "frontend__z_tool".to_string(),
                "Z tool".to_string(),
                object!({ "type": "object", "properties": { } }),
            ),
            Tool::new(
                "frontend__a_tool".to_string(),
                "A tool".to_string(),
                object!({ "type": "object", "properties": { } }),
            ),
        ];

        agent
            .add_extension(crate::agents::extension::ExtensionConfig::Frontend {
                name: "frontend".to_string(),
                description: "desc".to_string(),
                tools: frontend_tools,
                instructions: None,
                bundled: None,
                available_tools: vec![],
            })
            .await
            .unwrap();

        let working_dir = std::env::current_dir()?;
        let (tools, _toolshim_tools, _system_prompt) =
            agent.prepare_tools_and_prompt(&working_dir, None).await?;

        // Ensure both platform and frontend tools are present
        let names: Vec<String> = tools.iter().map(|t| t.name.clone().into_owned()).collect();
        assert!(names.iter().any(|n| n.starts_with("platform__")));
        assert!(names.iter().any(|n| n == "frontend__a_tool"));
        assert!(names.iter().any(|n| n == "frontend__z_tool"));

        // Verify the names are sorted ascending
        let mut sorted = names.clone();
        sorted.sort();
        assert_eq!(names, sorted);

        Ok(())
    }
}