Skip to main content

aether_cli/
runtime.rs

1use crate::error::CliError;
2use aether_auth::OAuthCredentialStorage;
3use aether_core::agent_spec::{AgentSpec, McpConfigSource};
4use aether_core::core::{AgentBuilder, AgentHandle, Prompt};
5use aether_core::events::{AgentMessage, UserMessage};
6use aether_core::mcp::McpBuilder;
7use aether_core::mcp::McpSpawnResult;
8use aether_core::mcp::mcp;
9use aether_core::mcp::run_mcp_task::McpCommand;
10use llm::{ChatMessage, LlmModel, ToolDefinition};
11use mcp_servers::McpBuilderExt;
12use mcp_utils::client::{McpClientEvent, McpServer, OAuthHandlerFactory};
13use mcp_utils::status::McpServerStatusEntry;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tokio::sync::mpsc::{Receiver, Sender};
17use tokio::task::JoinHandle;
18use tracing::debug;
19
20pub struct RuntimeBuilder {
21    cwd: PathBuf,
22    spec: AgentSpec,
23    mcp_config_sources: Vec<McpConfigSource>,
24    extra_mcp_servers: Vec<McpServer>,
25    oauth_applicator: Option<Box<dyn FnOnce(McpBuilder) -> McpBuilder + Send>>,
26    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
27    prompt_cache_key: Option<String>,
28}
29
30pub struct Runtime {
31    pub agent_tx: Sender<UserMessage>,
32    pub agent_rx: Receiver<AgentMessage>,
33    pub agent_handle: AgentHandle,
34    pub mcp_tx: Sender<McpCommand>,
35    pub event_rx: Receiver<McpClientEvent>,
36    pub server_statuses: Vec<McpServerStatusEntry>,
37    pub mcp_handle: JoinHandle<()>,
38}
39
40pub struct PromptInfo {
41    pub spec: AgentSpec,
42    pub tool_definitions: Vec<ToolDefinition>,
43}
44
45impl RuntimeBuilder {
46    pub fn new(cwd: &Path, model: &str) -> Result<Self, CliError> {
47        let cwd = cwd.canonicalize().map_err(CliError::IoError)?;
48        let parsed_model: LlmModel = model.parse().map_err(CliError::ModelError)?;
49        let spec = AgentSpec::default_spec(&parsed_model, None, Vec::new());
50
51        Ok(Self {
52            cwd,
53            spec,
54            mcp_config_sources: Vec::new(),
55            extra_mcp_servers: Vec::new(),
56            oauth_applicator: None,
57            oauth_credential_store: None,
58            prompt_cache_key: None,
59        })
60    }
61
62    pub fn from_spec(cwd: PathBuf, spec: AgentSpec) -> Self {
63        Self {
64            cwd,
65            spec,
66            mcp_config_sources: Vec::new(),
67            extra_mcp_servers: Vec::new(),
68            oauth_applicator: None,
69            oauth_credential_store: None,
70            prompt_cache_key: None,
71        }
72    }
73
74    pub fn prompt_cache_key(mut self, key: String) -> Self {
75        self.prompt_cache_key = Some(key);
76        self
77    }
78
79    /// Set MCP config source overrides. When non-empty, these completely
80    /// replace any sources resolved from the agent's `AgentSpec`.
81    pub fn mcp_sources(mut self, sources: Vec<McpConfigSource>) -> Self {
82        self.mcp_config_sources = sources;
83        self
84    }
85
86    pub fn extra_servers(mut self, servers: Vec<McpServer>) -> Self {
87        self.extra_mcp_servers = servers;
88        self
89    }
90
91    pub fn oauth_handler_factory(mut self, factory: OAuthHandlerFactory) -> Self {
92        self.oauth_applicator = Some(Box::new(|builder| builder.with_oauth_handler_factory(factory)));
93        self
94    }
95
96    pub fn oauth_credential_store(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
97        self.oauth_credential_store = Some(store);
98        self
99    }
100
101    pub async fn build(
102        self,
103        custom_prompt: Option<Prompt>,
104        messages: Option<Vec<ChatMessage>>,
105    ) -> Result<Runtime, CliError> {
106        let prompt_cache_key = self.prompt_cache_key.clone();
107        let oauth_credential_store = self.oauth_credential_store.clone();
108        let mcp = self.spawn_mcp().await?;
109
110        let filtered_tools = mcp.spec.tools.apply(mcp.tool_definitions);
111        let mut agent_builder = AgentBuilder::from_spec(&mcp.spec, vec![], oauth_credential_store)
112            .await
113            .map_err(|e| CliError::AgentError(e.to_string()))?
114            .tools(mcp.mcp_tx.clone(), filtered_tools);
115
116        if let Some(key) = prompt_cache_key {
117            agent_builder = agent_builder.prompt_cache_key(key);
118        }
119
120        if let Some(prompt) = custom_prompt {
121            agent_builder = agent_builder.system_prompt(prompt);
122        }
123
124        if let Some(msgs) = messages {
125            agent_builder = agent_builder.messages(msgs);
126        }
127
128        let (agent_tx, agent_rx, agent_handle) =
129            agent_builder.spawn().await.map_err(|e| CliError::AgentError(e.to_string()))?;
130
131        Ok(Runtime {
132            agent_tx,
133            agent_rx,
134            agent_handle,
135            mcp_tx: mcp.mcp_tx,
136            event_rx: mcp.event_rx,
137            server_statuses: mcp.server_statuses,
138            mcp_handle: mcp.mcp_handle,
139        })
140    }
141
142    pub async fn build_prompt_info(self) -> Result<PromptInfo, CliError> {
143        let mcp = self.spawn_mcp().await?;
144        let filtered_tools = mcp.spec.tools.apply(mcp.tool_definitions);
145        Ok(PromptInfo { spec: mcp.spec, tool_definitions: filtered_tools })
146    }
147
148    async fn spawn_mcp(self) -> Result<McpParts, CliError> {
149        let mut builder = mcp().with_builtin_servers(self.cwd.clone(), &self.cwd);
150
151        if !self.extra_mcp_servers.is_empty() {
152            builder = builder.with_servers(self.extra_mcp_servers);
153        }
154
155        if let Some(apply_oauth) = self.oauth_applicator {
156            builder = apply_oauth(builder);
157        }
158
159        if let Some(store) = self.oauth_credential_store {
160            builder = builder.with_oauth_credential_store(store);
161        }
162
163        let mcp_config_sources: Vec<McpConfigSource> = if self.mcp_config_sources.is_empty() {
164            self.spec.mcp_config_sources.clone()
165        } else {
166            self.mcp_config_sources
167        };
168
169        if !mcp_config_sources.is_empty() {
170            debug!("Loading MCP configs from: {:?}", mcp_config_sources);
171            builder = builder
172                .from_mcp_config_sources(&mcp_config_sources)
173                .await
174                .map_err(|e| CliError::McpError(e.to_string()))?;
175        }
176
177        let McpSpawnResult {
178            tool_definitions,
179            instructions,
180            server_statuses,
181            command_tx: mcp_tx,
182            event_rx,
183            handle: mcp_handle,
184        } = builder.spawn().await.map_err(|e| CliError::McpError(e.to_string()))?;
185
186        let mut spec = self.spec;
187        spec.prompts.push(Prompt::mcp_instructions(instructions));
188
189        Ok(McpParts { spec, tool_definitions, mcp_tx, event_rx, server_statuses, mcp_handle })
190    }
191}
192
193struct McpParts {
194    spec: AgentSpec,
195    tool_definitions: Vec<ToolDefinition>,
196    mcp_tx: Sender<McpCommand>,
197    event_rx: Receiver<McpClientEvent>,
198    server_statuses: Vec<McpServerStatusEntry>,
199    mcp_handle: JoinHandle<()>,
200}