Skip to main content

aether_cli/
runtime.rs

1use crate::error::CliError;
2use aether_core::agent_spec::{AgentSpec, McpConfigSource};
3use aether_core::core::{AgentBuilder, AgentHandle, Prompt};
4use aether_core::events::{AgentMessage, UserMessage};
5use aether_core::mcp::McpBuilder;
6use aether_core::mcp::McpSpawnResult;
7use aether_core::mcp::mcp;
8use aether_core::mcp::run_mcp_task::McpCommand;
9use llm::{ChatMessage, LlmModel, ToolDefinition};
10use mcp_servers::McpBuilderExt;
11use mcp_utils::client::oauth::OAuthHandler;
12use mcp_utils::client::{McpClientEvent, McpServerConfig};
13use mcp_utils::status::McpServerStatusEntry;
14use std::path::{Path, PathBuf};
15use tokio::sync::mpsc::{Receiver, Sender};
16use tokio::task::JoinHandle;
17use tracing::debug;
18
19pub struct RuntimeBuilder {
20    cwd: PathBuf,
21    spec: AgentSpec,
22    mcp_config_sources: Vec<McpConfigSource>,
23    extra_mcp_servers: Vec<McpServerConfig>,
24    oauth_applicator: Option<Box<dyn FnOnce(McpBuilder) -> McpBuilder + Send>>,
25    prompt_cache_key: Option<String>,
26}
27
/// A running agent plus the MCP layer that serves its tools, as produced by
/// `RuntimeBuilder::build`.
pub struct Runtime {
    /// Sends [`UserMessage`]s to the spawned agent.
    pub agent_tx: Sender<UserMessage>,
    /// Receives [`AgentMessage`]s emitted by the spawned agent.
    pub agent_rx: Receiver<AgentMessage>,
    /// Handle returned by `AgentBuilder::spawn` for the running agent.
    pub agent_handle: AgentHandle,
    /// Command channel into the MCP task; a clone of this sender is also given
    /// to the agent as its tool channel.
    pub mcp_tx: Sender<McpCommand>,
    /// Receives client events produced by the MCP layer.
    pub event_rx: Receiver<McpClientEvent>,
    /// Per-server status entries reported when the MCP layer was spawned.
    pub server_statuses: Vec<McpServerStatusEntry>,
    /// Join handle of the MCP task. Dropping a tokio `JoinHandle` detaches the
    /// task; it does not stop it.
    pub mcp_handle: JoinHandle<()>,
}
37
/// Data needed to render an agent's prompt without spawning the agent, as
/// produced by `RuntimeBuilder::build_prompt_info`.
pub struct PromptInfo {
    /// The agent spec, with the MCP servers' instructions appended to its prompts.
    pub spec: AgentSpec,
    /// MCP tool definitions after filtering through the spec's tool filter.
    pub tool_definitions: Vec<ToolDefinition>,
}
42
43impl RuntimeBuilder {
44    pub fn new(cwd: &Path, model: &str) -> Result<Self, CliError> {
45        let cwd = cwd.canonicalize().map_err(CliError::IoError)?;
46        let parsed_model: LlmModel = model.parse().map_err(CliError::ModelError)?;
47        let spec = AgentSpec::default_spec(&parsed_model, None, Vec::new());
48
49        Ok(Self {
50            cwd,
51            spec,
52            mcp_config_sources: Vec::new(),
53            extra_mcp_servers: Vec::new(),
54            oauth_applicator: None,
55            prompt_cache_key: None,
56        })
57    }
58
59    pub fn from_spec(cwd: PathBuf, spec: AgentSpec) -> Self {
60        Self {
61            cwd,
62            spec,
63            mcp_config_sources: Vec::new(),
64            extra_mcp_servers: Vec::new(),
65            oauth_applicator: None,
66            prompt_cache_key: None,
67        }
68    }
69
70    pub fn prompt_cache_key(mut self, key: String) -> Self {
71        self.prompt_cache_key = Some(key);
72        self
73    }
74
75    /// Set MCP config source overrides. When non-empty, these completely
76    /// replace any sources resolved from the agent's `AgentSpec`.
77    pub fn mcp_sources(mut self, sources: Vec<McpConfigSource>) -> Self {
78        self.mcp_config_sources = sources;
79        self
80    }
81
82    pub fn extra_servers(mut self, servers: Vec<McpServerConfig>) -> Self {
83        self.extra_mcp_servers = servers;
84        self
85    }
86
87    pub fn oauth_handler<H: OAuthHandler + 'static>(mut self, handler: H) -> Self {
88        self.oauth_applicator = Some(Box::new(|builder| builder.with_oauth_handler(handler)));
89        self
90    }
91
92    pub async fn build(
93        self,
94        custom_prompt: Option<Prompt>,
95        messages: Option<Vec<ChatMessage>>,
96    ) -> Result<Runtime, CliError> {
97        let prompt_cache_key = self.prompt_cache_key.clone();
98        let mcp = self.spawn_mcp().await?;
99
100        let filtered_tools = mcp.spec.tools.apply(mcp.tool_definitions);
101        let mut agent_builder = AgentBuilder::from_spec(&mcp.spec, vec![])
102            .await
103            .map_err(|e| CliError::AgentError(e.to_string()))?
104            .tools(mcp.mcp_tx.clone(), filtered_tools);
105
106        if let Some(key) = prompt_cache_key {
107            agent_builder = agent_builder.prompt_cache_key(key);
108        }
109
110        if let Some(prompt) = custom_prompt {
111            agent_builder = agent_builder.system_prompt(prompt);
112        }
113
114        if let Some(msgs) = messages {
115            agent_builder = agent_builder.messages(msgs);
116        }
117
118        let (agent_tx, agent_rx, agent_handle) =
119            agent_builder.spawn().await.map_err(|e| CliError::AgentError(e.to_string()))?;
120
121        Ok(Runtime {
122            agent_tx,
123            agent_rx,
124            agent_handle,
125            mcp_tx: mcp.mcp_tx,
126            event_rx: mcp.event_rx,
127            server_statuses: mcp.server_statuses,
128            mcp_handle: mcp.mcp_handle,
129        })
130    }
131
132    pub async fn build_prompt_info(self) -> Result<PromptInfo, CliError> {
133        let mcp = self.spawn_mcp().await?;
134        let filtered_tools = mcp.spec.tools.apply(mcp.tool_definitions);
135        Ok(PromptInfo { spec: mcp.spec, tool_definitions: filtered_tools })
136    }
137
138    async fn spawn_mcp(self) -> Result<McpParts, CliError> {
139        let mut builder = mcp().with_builtin_servers(self.cwd.clone(), &self.cwd);
140
141        if !self.extra_mcp_servers.is_empty() {
142            builder = builder.with_servers(self.extra_mcp_servers);
143        }
144
145        if let Some(apply_oauth) = self.oauth_applicator {
146            builder = apply_oauth(builder);
147        }
148
149        let mcp_config_sources: Vec<McpConfigSource> = if self.mcp_config_sources.is_empty() {
150            self.spec.mcp_config_sources.clone()
151        } else {
152            self.mcp_config_sources
153        };
154
155        if !mcp_config_sources.is_empty() {
156            debug!("Loading MCP configs from: {:?}", mcp_config_sources);
157            builder = builder
158                .from_mcp_config_sources(&mcp_config_sources)
159                .await
160                .map_err(|e| CliError::McpError(e.to_string()))?;
161        }
162
163        let McpSpawnResult {
164            tool_definitions,
165            instructions,
166            server_statuses,
167            command_tx: mcp_tx,
168            event_rx,
169            handle: mcp_handle,
170        } = builder.spawn().await.map_err(|e| CliError::McpError(e.to_string()))?;
171
172        let mut spec = self.spec;
173        spec.prompts.push(Prompt::mcp_instructions(instructions));
174
175        Ok(McpParts { spec, tool_definitions, mcp_tx, event_rx, server_statuses, mcp_handle })
176    }
177}
178
/// Internal bundle returned by `RuntimeBuilder::spawn_mcp`: the updated spec
/// plus everything produced by spawning the MCP layer.
struct McpParts {
    /// Agent spec with the MCP servers' instructions appended to its prompts.
    spec: AgentSpec,
    /// Unfiltered tool definitions reported by the MCP layer.
    tool_definitions: Vec<ToolDefinition>,
    /// Command channel into the MCP task.
    mcp_tx: Sender<McpCommand>,
    /// Receives client events from the MCP layer.
    event_rx: Receiver<McpClientEvent>,
    /// Per-server status entries reported at spawn time.
    server_statuses: Vec<McpServerStatusEntry>,
    /// Join handle of the spawned MCP task.
    mcp_handle: JoinHandle<()>,
}