Skip to main content

aster_cli/commands/
acp.rs

1use anyhow::Result;
2use aster::agents::extension::{Envs, PlatformExtensionContext, PLATFORM_EXTENSIONS};
3use aster::agents::{Agent, ExtensionConfig, SessionConfig};
4use aster::config::{get_all_extensions, Config};
5use aster::conversation::message::{ActionRequiredData, Message, MessageContent};
6use aster::conversation::Conversation;
7use aster::mcp_utils::ToolResult;
8use aster::permission::permission_confirmation::PrincipalType;
9use aster::permission::{Permission, PermissionConfirmation};
10use aster::providers::create;
11use aster::session::session_manager::SessionType;
12use aster::session::SessionManager;
13use rmcp::model::{CallToolResult, RawContent, ResourceContents, Role};
14use sacp::schema::{
15    AgentCapabilities, AuthenticateRequest, AuthenticateResponse, BlobResourceContents,
16    CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource,
17    EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse,
18    LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest,
19    NewSessionResponse, PermissionOption, PermissionOptionId, PermissionOptionKind,
20    PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome,
21    RequestPermissionRequest, ResourceLink, SessionId, SessionNotification, SessionUpdate,
22    StopReason, TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId,
23    ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind,
24};
25use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx};
26use std::collections::{HashMap, HashSet};
27use std::fs;
28use std::sync::Arc;
29use tokio::sync::Mutex;
30use tokio::task::JoinSet;
31use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
32use tokio_util::sync::CancellationToken;
33use tracing::{debug, error, info, warn};
34use url::Url;
35
36struct AsterAcpSession {
37    messages: Conversation,
38    tool_requests: HashMap<String, aster::conversation::message::ToolRequest>,
39    cancel_token: Option<CancellationToken>,
40}
41
42struct AsterAcpAgent {
43    sessions: Arc<Mutex<HashMap<String, AsterAcpSession>>>,
44    agent: Arc<Agent>,
45}
46
47fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> {
48    match mcp_server {
49        McpServer::Stdio(stdio) => Ok(ExtensionConfig::Stdio {
50            name: stdio.name,
51            description: String::new(),
52            cmd: stdio.command.to_string_lossy().to_string(),
53            args: stdio.args,
54            envs: Envs::new(stdio.env.into_iter().map(|e| (e.name, e.value)).collect()),
55            env_keys: vec![],
56            timeout: None,
57            bundled: Some(false),
58            available_tools: vec![],
59        }),
60        McpServer::Http(http) => Ok(ExtensionConfig::StreamableHttp {
61            name: http.name,
62            description: String::new(),
63            uri: http.url,
64            envs: Envs::default(),
65            env_keys: vec![],
66            headers: http
67                .headers
68                .into_iter()
69                .map(|h| (h.name, h.value))
70                .collect(),
71            timeout: None,
72            bundled: Some(false),
73            available_tools: vec![],
74        }),
75        McpServer::Sse(_) => Err("SSE is unsupported, migrate to streamable_http".to_string()),
76        _ => Err("Unknown MCP server type".to_string()),
77    }
78}
79
80fn create_tool_location(path: &str, line: Option<u32>) -> ToolCallLocation {
81    let mut loc = ToolCallLocation::new(path);
82    if let Some(l) = line {
83        loc = loc.line(l);
84    }
85    loc
86}
87
88fn extract_tool_locations(
89    tool_request: &aster::conversation::message::ToolRequest,
90    tool_response: &aster::conversation::message::ToolResponse,
91) -> Vec<ToolCallLocation> {
92    let mut locations = Vec::new();
93
94    // Get the tool call details
95    if let Ok(tool_call) = &tool_request.tool_call {
96        // Only process text_editor tool
97        if tool_call.name != "developer__text_editor" {
98            return locations;
99        }
100
101        // Extract the path from arguments
102        let path_str = tool_call
103            .arguments
104            .as_ref()
105            .and_then(|args| args.get("path"))
106            .and_then(|p| p.as_str());
107
108        if let Some(path_str) = path_str {
109            // Get the command type
110            let command = tool_call
111                .arguments
112                .as_ref()
113                .and_then(|args| args.get("command"))
114                .and_then(|c| c.as_str());
115
116            // Extract line numbers from the response content
117            if let Ok(result) = &tool_response.tool_result {
118                for content in &result.content {
119                    if let RawContent::Text(text_content) = &content.raw {
120                        let text = &text_content.text;
121
122                        // Parse line numbers based on command type and response format
123                        match command {
124                            Some("view") => {
125                                // For view command, look for "lines X-Y" pattern in header
126                                let line = extract_view_line_range(text)
127                                    .map(|range| range.0 as u32)
128                                    .or(Some(1));
129                                locations.push(create_tool_location(path_str, line));
130                            }
131                            Some("str_replace") | Some("insert") => {
132                                // For edits, extract the first line number from the snippet
133                                let line = extract_first_line_number(text)
134                                    .map(|l| l as u32)
135                                    .or(Some(1));
136                                locations.push(create_tool_location(path_str, line));
137                            }
138                            Some("write") => {
139                                // For write, just point to the beginning of the file
140                                locations.push(create_tool_location(path_str, Some(1)));
141                            }
142                            _ => {
143                                // For other commands or unknown, default to line 1
144                                locations.push(create_tool_location(path_str, Some(1)));
145                            }
146                        }
147                        break; // Only process first text content
148                    }
149                }
150            }
151
152            // If we didn't find any locations yet, add a default one
153            if locations.is_empty() {
154                locations.push(create_tool_location(path_str, Some(1)));
155            }
156        }
157    }
158
159    locations
160}
161
162fn extract_view_line_range(text: &str) -> Option<(usize, usize)> {
163    // Pattern: "(lines X-Y)" or "(lines X-end)"
164    let re = regex::Regex::new(r"\(lines (\d+)-(\d+|end)\)").ok()?;
165    if let Some(caps) = re.captures(text) {
166        let start = caps.get(1)?.as_str().parse::<usize>().ok()?;
167        let end = if caps.get(2)?.as_str() == "end" {
168            start // Use start as a reasonable default
169        } else {
170            caps.get(2)?.as_str().parse::<usize>().ok()?
171        };
172        return Some((start, end));
173    }
174    None
175}
176
177fn extract_first_line_number(text: &str) -> Option<usize> {
178    // Pattern: "123: " at the start of a line within a code block
179    let re = regex::Regex::new(r"```[^\n]*\n(\d+):").ok()?;
180    if let Some(caps) = re.captures(text) {
181        return caps.get(1)?.as_str().parse::<usize>().ok();
182    }
183    None
184}
185
186fn read_resource_link(link: ResourceLink) -> Option<String> {
187    let url = Url::parse(&link.uri).ok()?;
188    if url.scheme() == "file" {
189        let path = url.to_file_path().ok()?;
190        let contents = fs::read_to_string(&path).ok()?;
191
192        Some(format!(
193            "\n\n# {}\n```\n{}\n```",
194            path.to_string_lossy(),
195            contents
196        ))
197    } else {
198        None
199    }
200}
201
/// Turn an internal tool name into a human-readable title.
///
/// `extension__tool` names become `"Extension: Tool"`. Names without a `__` separator are
/// title-cased as a whole. In both parts, underscores become word breaks, runs of
/// separators collapse, and the first character of each word is upper-cased (the rest is
/// kept as-is).
fn format_tool_name(tool_name: &str) -> String {
    /// Replace `_` with spaces and upper-case the first character of every word.
    fn title_case(s: &str) -> String {
        s.replace('_', " ")
            .split_whitespace()
            .map(|word| {
                let mut chars = word.chars();
                match chars.next() {
                    Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
                    None => String::new(),
                }
            })
            .collect::<Vec<_>>()
            .join(" ")
    }

    match tool_name.split_once("__") {
        Some((extension, tool)) => format!("{}: {}", title_case(extension), title_case(tool)),
        // Fallback for tools without a double underscore.
        None => title_case(tool_name),
    }
}
242
243async fn add_builtins(agent: &Agent, builtins: Vec<String>) {
244    for builtin in builtins {
245        let config = if PLATFORM_EXTENSIONS.contains_key(builtin.as_str()) {
246            ExtensionConfig::Platform {
247                name: builtin.clone(),
248                bundled: None,
249                description: builtin.clone(),
250                available_tools: Vec::new(),
251            }
252        } else {
253            ExtensionConfig::Builtin {
254                name: builtin.clone(),
255                display_name: None,
256                timeout: None,
257                bundled: None,
258                description: builtin.clone(),
259                available_tools: Vec::new(),
260            }
261        };
262        match agent.add_extension(config).await {
263            Ok(_) => info!(extension = %builtin, "builtin extension loaded"),
264            Err(e) => warn!(extension = %builtin, error = %e, "builtin extension load failed"),
265        }
266    }
267}
268
269impl AsterAcpAgent {
270    async fn new(builtins: Vec<String>) -> Result<Self> {
271        let config = Config::global();
272
273        let provider_name: String = config
274            .get_aster_provider()
275            .map_err(|e| anyhow::anyhow!("No provider configured: {}", e))?;
276
277        let model_name: String = config
278            .get_aster_model()
279            .map_err(|e| anyhow::anyhow!("No model configured: {}", e))?;
280
281        let model_config = aster::model::ModelConfig {
282            model_name: model_name.clone(),
283            context_limit: None,
284            temperature: None,
285            max_tokens: None,
286            toolshim: false,
287            toolshim_model: None,
288            fast_model: None,
289        };
290        let provider = create(&provider_name, model_config).await?;
291
292        let session = SessionManager::create_session(
293            std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
294            "ACP Session".to_string(),
295            SessionType::Hidden,
296        )
297        .await?;
298
299        let agent = Agent::new();
300        agent.update_provider(provider.clone(), &session.id).await?;
301
302        let extensions_to_run: Vec<_> = get_all_extensions()
303            .into_iter()
304            .filter(|ext| ext.enabled)
305            .map(|ext| ext.config)
306            .collect();
307
308        let agent_ptr = Arc::new(agent);
309
310        // ACP loads the same default extensions as CLI
311        agent_ptr
312            .extension_manager
313            .set_context(PlatformExtensionContext {
314                session_id: Some(session.id.clone()),
315                extension_manager: Some(Arc::downgrade(&agent_ptr.extension_manager)),
316            })
317            .await;
318
319        let mut set = JoinSet::new();
320        let mut waiting_on = HashSet::new();
321
322        for extension in extensions_to_run {
323            waiting_on.insert(extension.name());
324            let agent_ptr_clone = agent_ptr.clone();
325            set.spawn(async move {
326                (
327                    extension.name(),
328                    agent_ptr_clone.add_extension(extension.clone()).await,
329                )
330            });
331        }
332
333        while let Some(result) = set.join_next().await {
334            match result {
335                Ok((name, Ok(_))) => {
336                    waiting_on.remove(&name);
337                    info!(extension = %name, "extension loaded");
338                }
339                Ok((name, Err(e))) => {
340                    warn!(extension = %name, error = %e, "extension load failed");
341                    waiting_on.remove(&name);
342                }
343                Err(e) => {
344                    error!(error = %e, "extension task error");
345                }
346            }
347        }
348
349        add_builtins(&agent_ptr, builtins).await;
350
351        Ok(Self {
352            sessions: Arc::new(Mutex::new(HashMap::new())),
353            agent: agent_ptr,
354        })
355    }
356
357    fn convert_acp_prompt_to_message(&self, prompt: Vec<ContentBlock>) -> Message {
358        let mut user_message = Message::user();
359
360        // Process all content blocks from the prompt
361        for block in prompt {
362            match block {
363                ContentBlock::Text(text) => {
364                    user_message = user_message.with_text(&text.text);
365                }
366                ContentBlock::Image(image) => {
367                    // Aster supports images via base64 encoded data
368                    // The ACP ImageContent has data as a String directly
369                    user_message = user_message.with_image(&image.data, &image.mime_type);
370                }
371                ContentBlock::Resource(resource) => {
372                    // Embed resource content as text with context
373                    match &resource.resource {
374                        EmbeddedResourceResource::TextResourceContents(text_resource) => {
375                            let header = format!("--- Resource: {} ---\n", text_resource.uri);
376                            let content = format!("{}{}\n---\n", header, text_resource.text);
377                            user_message = user_message.with_text(&content);
378                        }
379                        _ => {
380                            // Ignore non-text resources for now
381                        }
382                    }
383                }
384                ContentBlock::ResourceLink(link) => {
385                    if let Some(text) = read_resource_link(link) {
386                        user_message = user_message.with_text(text)
387                    }
388                }
389                ContentBlock::Audio(..) => (),
390                _ => (), // Handle any future ContentBlock variants
391            }
392        }
393
394        user_message
395    }
396
397    async fn handle_message_content(
398        &self,
399        content_item: &MessageContent,
400        session_id: &SessionId,
401        session: &mut AsterAcpSession,
402        cx: &JrConnectionCx<AgentToClient>,
403    ) -> Result<(), sacp::Error> {
404        match content_item {
405            MessageContent::Text(text) => {
406                // Stream text to the client
407                cx.send_notification(SessionNotification::new(
408                    session_id.clone(),
409                    SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text(
410                        TextContent::new(&text.text),
411                    ))),
412                ))?;
413            }
414            MessageContent::ToolRequest(tool_request) => {
415                self.handle_tool_request(tool_request, session_id, session, cx)
416                    .await?;
417            }
418            MessageContent::ToolResponse(tool_response) => {
419                self.handle_tool_response(tool_response, session_id, session, cx)
420                    .await?;
421            }
422            MessageContent::Thinking(thinking) => {
423                // Stream thinking/reasoning content as thought chunks
424                cx.send_notification(SessionNotification::new(
425                    session_id.clone(),
426                    SessionUpdate::AgentThoughtChunk(ContentChunk::new(ContentBlock::Text(
427                        TextContent::new(&thinking.thinking),
428                    ))),
429                ))?;
430            }
431            MessageContent::ActionRequired(action_required) => {
432                if let ActionRequiredData::ToolConfirmation {
433                    id,
434                    tool_name,
435                    arguments,
436                    prompt,
437                } = &action_required.data
438                {
439                    self.handle_tool_permission_request(
440                        id.clone(),
441                        tool_name.clone(),
442                        arguments.clone(),
443                        prompt.clone(),
444                        session_id,
445                        cx,
446                    )?;
447                }
448            }
449            _ => {
450                // Ignore other content types for now
451            }
452        }
453        Ok(())
454    }
455
456    async fn handle_tool_request(
457        &self,
458        tool_request: &aster::conversation::message::ToolRequest,
459        session_id: &SessionId,
460        session: &mut AsterAcpSession,
461        cx: &JrConnectionCx<AgentToClient>,
462    ) -> Result<(), sacp::Error> {
463        // Store the tool request for later use in response handling
464        session
465            .tool_requests
466            .insert(tool_request.id.clone(), tool_request.clone());
467
468        // Extract tool name from the ToolCall if successful
469        let tool_name = match &tool_request.tool_call {
470            Ok(tool_call) => tool_call.name.to_string(),
471            Err(_) => "error".to_string(),
472        };
473
474        // Send tool call notification using the provider's tool call ID directly
475        cx.send_notification(SessionNotification::new(
476            session_id.clone(),
477            SessionUpdate::ToolCall(
478                ToolCall::new(
479                    ToolCallId::new(tool_request.id.clone()),
480                    format_tool_name(&tool_name),
481                )
482                .status(ToolCallStatus::Pending),
483            ),
484        ))?;
485
486        Ok(())
487    }
488
489    async fn handle_tool_response(
490        &self,
491        tool_response: &aster::conversation::message::ToolResponse,
492        session_id: &SessionId,
493        session: &mut AsterAcpSession,
494        cx: &JrConnectionCx<AgentToClient>,
495    ) -> Result<(), sacp::Error> {
496        // Determine if the tool call succeeded or failed
497        let status = if tool_response.tool_result.is_ok() {
498            ToolCallStatus::Completed
499        } else {
500            ToolCallStatus::Failed
501        };
502
503        let content = build_tool_call_content(&tool_response.tool_result);
504
505        // Extract locations from the tool request and response
506        let locations = if let Some(tool_request) = session.tool_requests.get(&tool_response.id) {
507            extract_tool_locations(tool_request, tool_response)
508        } else {
509            Vec::new()
510        };
511
512        // Send status update using provider's tool call ID directly
513        let mut fields = ToolCallUpdateFields::new().status(status).content(content);
514        if !locations.is_empty() {
515            fields = fields.locations(locations);
516        }
517        cx.send_notification(SessionNotification::new(
518            session_id.clone(),
519            SessionUpdate::ToolCallUpdate(ToolCallUpdate::new(
520                ToolCallId::new(tool_response.id.clone()),
521                fields,
522            )),
523        ))?;
524
525        Ok(())
526    }
527
528    fn handle_tool_permission_request(
529        &self,
530        request_id: String,
531        tool_name: String,
532        arguments: serde_json::Map<String, serde_json::Value>,
533        prompt: Option<String>,
534        session_id: &SessionId,
535        cx: &JrConnectionCx<AgentToClient>,
536    ) -> Result<(), sacp::Error> {
537        let cx = cx.clone();
538        let agent = self.agent.clone();
539        let session_id = session_id.clone();
540
541        let formatted_name = format_tool_name(&tool_name);
542
543        // Use the request_id (provider's tool call ID) directly
544        let mut fields = ToolCallUpdateFields::new()
545            .title(formatted_name)
546            .kind(ToolKind::default())
547            .status(ToolCallStatus::Pending)
548            .raw_input(serde_json::Value::Object(arguments));
549        if let Some(p) = prompt {
550            fields = fields.content(vec![ToolCallContent::Content(Content::new(
551                ContentBlock::Text(TextContent::new(p)),
552            ))]);
553        }
554        let tool_call_update = ToolCallUpdate::new(ToolCallId::new(request_id.clone()), fields);
555
556        fn option(kind: PermissionOptionKind) -> PermissionOption {
557            let id = serde_json::to_value(kind)
558                .unwrap()
559                .as_str()
560                .unwrap()
561                .to_string();
562            PermissionOption::new(PermissionOptionId::from(id.clone()), id, kind)
563        }
564        let options = vec![
565            option(PermissionOptionKind::AllowAlways),
566            option(PermissionOptionKind::AllowOnce),
567            option(PermissionOptionKind::RejectOnce),
568        ];
569
570        let permission_request =
571            RequestPermissionRequest::new(session_id, tool_call_update, options);
572
573        cx.send_request(permission_request)
574            .on_receiving_result(move |result| async move {
575                match result {
576                    Ok(response) => {
577                        agent
578                            .handle_confirmation(
579                                request_id,
580                                outcome_to_confirmation(&response.outcome),
581                            )
582                            .await;
583                        Ok(())
584                    }
585                    Err(e) => {
586                        error!(error = ?e, "permission request failed");
587                        agent
588                            .handle_confirmation(
589                                request_id,
590                                PermissionConfirmation {
591                                    principal_type: PrincipalType::Tool,
592                                    permission: Permission::Cancel,
593                                },
594                            )
595                            .await;
596                        Ok(())
597                    }
598                }
599            })?;
600
601        Ok(())
602    }
603}
604
605fn outcome_to_confirmation(outcome: &RequestPermissionOutcome) -> PermissionConfirmation {
606    let permission = match outcome {
607        RequestPermissionOutcome::Cancelled => Permission::Cancel,
608        RequestPermissionOutcome::Selected(selected) => {
609            match serde_json::from_value::<PermissionOptionKind>(serde_json::Value::String(
610                selected.option_id.to_string(),
611            )) {
612                Ok(PermissionOptionKind::AllowAlways) => Permission::AlwaysAllow,
613                Ok(PermissionOptionKind::AllowOnce) => Permission::AllowOnce,
614                Ok(PermissionOptionKind::RejectOnce | PermissionOptionKind::RejectAlways) => {
615                    Permission::DenyOnce
616                }
617                Ok(_) => Permission::Cancel, // Handle any future permission kinds
618                Err(_) => Permission::Cancel,
619            }
620        }
621        _ => Permission::Cancel, // Handle any future variants
622    };
623    PermissionConfirmation {
624        principal_type: PrincipalType::Tool,
625        permission,
626    }
627}
628
629fn build_tool_call_content(tool_result: &ToolResult<CallToolResult>) -> Vec<ToolCallContent> {
630    match tool_result {
631        Ok(result) => result
632            .content
633            .iter()
634            .filter_map(|content| match &content.raw {
635                RawContent::Text(val) => Some(ToolCallContent::Content(Content::new(
636                    ContentBlock::Text(TextContent::new(&val.text)),
637                ))),
638                RawContent::Image(val) => Some(ToolCallContent::Content(Content::new(
639                    ContentBlock::Image(ImageContent::new(&val.data, &val.mime_type)),
640                ))),
641                RawContent::Resource(val) => {
642                    let resource = match &val.resource {
643                        ResourceContents::TextResourceContents {
644                            mime_type,
645                            text,
646                            uri,
647                            ..
648                        } => {
649                            let mut r = TextResourceContents::new(text.clone(), uri.clone());
650                            if let Some(mt) = mime_type {
651                                r = r.mime_type(mt.clone());
652                            }
653                            EmbeddedResourceResource::TextResourceContents(r)
654                        }
655                        ResourceContents::BlobResourceContents {
656                            mime_type,
657                            blob,
658                            uri,
659                            ..
660                        } => {
661                            let mut r = BlobResourceContents::new(blob.clone(), uri.clone());
662                            if let Some(mt) = mime_type {
663                                r = r.mime_type(mt.clone());
664                            }
665                            EmbeddedResourceResource::BlobResourceContents(r)
666                        }
667                    };
668                    Some(ToolCallContent::Content(Content::new(
669                        ContentBlock::Resource(EmbeddedResource::new(resource)),
670                    )))
671                }
672                RawContent::Audio(_) => {
673                    // Audio content is not supported in ACP ContentBlock, skip it
674                    None
675                }
676                RawContent::ResourceLink(_) => {
677                    // ResourceLink content is not supported in ACP ContentBlock, skip it
678                    None
679                }
680            })
681            .collect(),
682        Err(_) => Vec::new(),
683    }
684}
685
686impl AsterAcpAgent {
687    async fn on_initialize(
688        &self,
689        args: InitializeRequest,
690    ) -> Result<InitializeResponse, sacp::Error> {
691        debug!(?args, "initialize request");
692
693        // Advertise Aster's capabilities
694        let capabilities = AgentCapabilities::new()
695            .load_session(true)
696            .prompt_capabilities(
697                PromptCapabilities::new()
698                    .image(true)
699                    .audio(false)
700                    .embedded_context(true),
701            )
702            .mcp_capabilities(McpCapabilities::new().http(true));
703        Ok(InitializeResponse::new(args.protocol_version).agent_capabilities(capabilities))
704    }
705
706    async fn on_new_session(
707        &self,
708        args: NewSessionRequest,
709    ) -> Result<NewSessionResponse, sacp::Error> {
710        debug!(?args, "new session request");
711
712        let aster_session = SessionManager::create_session(
713            std::env::current_dir().unwrap_or_default(),
714            "ACP Session".to_string(), // just an initial name - may be replaced by maybe_update_name
715            SessionType::User,
716        )
717        .await
718        .map_err(|e| {
719            sacp::Error::new(
720                sacp::ErrorCode::InternalError.into(),
721                format!("Failed to create session: {}", e),
722            )
723        })?;
724
725        let session = AsterAcpSession {
726            messages: Conversation::new_unvalidated(Vec::new()),
727            tool_requests: HashMap::new(),
728            cancel_token: None,
729        };
730
731        let mut sessions = self.sessions.lock().await;
732        sessions.insert(aster_session.id.clone(), session);
733
734        // Add MCP servers specified in the session request
735        for mcp_server in args.mcp_servers {
736            let config = match mcp_server_to_extension_config(mcp_server) {
737                Ok(c) => c,
738                Err(msg) => {
739                    return Err(sacp::Error::new(sacp::ErrorCode::InvalidParams.into(), msg));
740                }
741            };
742            let name = config.name().to_string();
743            if let Err(e) = self.agent.add_extension(config).await {
744                return Err(sacp::Error::new(
745                    sacp::ErrorCode::InternalError.into(),
746                    format!("Failed to add MCP server '{}': {}", name, e),
747                ));
748            }
749        }
750
751        info!(
752            session_id = %aster_session.id,
753            session_type = "acp",
754            "Session started"
755        );
756
757        Ok(NewSessionResponse::new(SessionId::new(aster_session.id)))
758    }
759
760    async fn on_load_session(
761        &self,
762        args: LoadSessionRequest,
763        cx: &JrConnectionCx<AgentToClient>,
764    ) -> Result<LoadSessionResponse, sacp::Error> {
765        debug!(?args, "load session request");
766
767        let session_id = args.session_id.0.to_string();
768
769        let aster_session = SessionManager::get_session(&session_id, true)
770            .await
771            .map_err(|e| {
772                sacp::Error::new(
773                    sacp::ErrorCode::InvalidParams.into(),
774                    format!("Failed to load session {}: {}", session_id, e),
775                )
776            })?;
777
778        let conversation = aster_session.conversation.ok_or_else(|| {
779            sacp::Error::new(
780                sacp::ErrorCode::InternalError.into(),
781                format!("Session {} has no conversation data", session_id),
782            )
783        })?;
784
785        SessionManager::update_session(&session_id)
786            .working_dir(args.cwd.clone())
787            .apply()
788            .await
789            .map_err(|e| {
790                sacp::Error::new(
791                    sacp::ErrorCode::InternalError.into(),
792                    format!("Failed to update session working directory: {}", e),
793                )
794            })?;
795
796        let mut session = AsterAcpSession {
797            messages: conversation.clone(),
798            tool_requests: HashMap::new(),
799            cancel_token: None,
800        };
801
802        // Replay conversation history to client
803        for message in conversation.messages() {
804            // Only replay user-visible messages
805            if !message.metadata.user_visible {
806                continue;
807            }
808
809            for content_item in &message.content {
810                match content_item {
811                    MessageContent::Text(text) => {
812                        let chunk =
813                            ContentChunk::new(ContentBlock::Text(TextContent::new(&text.text)));
814                        let update = match message.role {
815                            Role::User => SessionUpdate::UserMessageChunk(chunk),
816                            Role::Assistant => SessionUpdate::AgentMessageChunk(chunk),
817                        };
818                        cx.send_notification(SessionNotification::new(
819                            args.session_id.clone(),
820                            update,
821                        ))?;
822                    }
823                    MessageContent::ToolRequest(tool_request) => {
824                        self.handle_tool_request(tool_request, &args.session_id, &mut session, cx)
825                            .await?;
826                    }
827                    MessageContent::ToolResponse(tool_response) => {
828                        self.handle_tool_response(
829                            tool_response,
830                            &args.session_id,
831                            &mut session,
832                            cx,
833                        )
834                        .await?;
835                    }
836                    MessageContent::Thinking(thinking) => {
837                        cx.send_notification(SessionNotification::new(
838                            args.session_id.clone(),
839                            SessionUpdate::AgentThoughtChunk(ContentChunk::new(
840                                ContentBlock::Text(TextContent::new(&thinking.thinking)),
841                            )),
842                        ))?;
843                    }
844                    _ => {
845                        // Ignore other content types
846                    }
847                }
848            }
849        }
850
851        let mut sessions = self.sessions.lock().await;
852        sessions.insert(session_id.clone(), session);
853
854        info!(
855            session_id = %session_id,
856            session_type = "acp",
857            "Session loaded"
858        );
859
860        Ok(LoadSessionResponse::new())
861    }
862
863    async fn on_prompt(
864        &self,
865        args: PromptRequest,
866        cx: &JrConnectionCx<AgentToClient>,
867    ) -> Result<PromptResponse, sacp::Error> {
868        let session_id = args.session_id.0.to_string();
869        let cancel_token = CancellationToken::new();
870
871        {
872            let mut sessions = self.sessions.lock().await;
873            let session = sessions.get_mut(&session_id).ok_or_else(|| {
874                sacp::Error::new(
875                    sacp::ErrorCode::InvalidParams.into(),
876                    format!("Session not found: {}", session_id),
877                )
878            })?;
879            session.cancel_token = Some(cancel_token.clone());
880        }
881
882        let user_message = self.convert_acp_prompt_to_message(args.prompt);
883
884        let session_config = SessionConfig {
885            id: session_id.clone(),
886            schedule_id: None,
887            max_turns: None,
888            retry_config: None,
889            system_prompt: None,
890        };
891
892        let mut stream = self
893            .agent
894            .reply(user_message, session_config, Some(cancel_token.clone()))
895            .await
896            .map_err(|e| {
897                sacp::Error::new(
898                    sacp::ErrorCode::InternalError.into(),
899                    format!("Error getting agent reply: {}", e),
900                )
901            })?;
902
903        use futures::StreamExt;
904
905        let mut was_cancelled = false;
906
907        while let Some(event) = stream.next().await {
908            if cancel_token.is_cancelled() {
909                was_cancelled = true;
910                break;
911            }
912
913            match event {
914                Ok(aster::agents::AgentEvent::Message(message)) => {
915                    let mut sessions = self.sessions.lock().await;
916                    let session = sessions.get_mut(&session_id).ok_or_else(|| {
917                        sacp::Error::new(
918                            sacp::ErrorCode::InvalidParams.into(),
919                            format!("Session not found: {}", session_id),
920                        )
921                    })?;
922
923                    session.messages.push(message.clone());
924
925                    for content_item in &message.content {
926                        self.handle_message_content(content_item, &args.session_id, session, cx)
927                            .await?;
928                    }
929                }
930                Ok(_) => {}
931                Err(e) => {
932                    return Err(sacp::Error::new(
933                        sacp::ErrorCode::InternalError.into(),
934                        format!("Error in agent response stream: {}", e),
935                    ));
936                }
937            }
938        }
939
940        let mut sessions = self.sessions.lock().await;
941        if let Some(session) = sessions.get_mut(&session_id) {
942            session.cancel_token = None;
943        }
944
945        let stop_reason = if was_cancelled {
946            StopReason::Cancelled
947        } else {
948            StopReason::EndTurn
949        };
950        Ok(PromptResponse::new(stop_reason))
951    }
952
953    async fn on_cancel(&self, args: CancelNotification) -> Result<(), sacp::Error> {
954        debug!(?args, "cancel request");
955
956        let session_id = args.session_id.0.to_string();
957        let mut sessions = self.sessions.lock().await;
958
959        if let Some(session) = sessions.get_mut(&session_id) {
960            if let Some(ref token) = session.cancel_token {
961                info!(session_id = %session_id, "prompt cancelled");
962                token.cancel();
963            }
964        } else {
965            warn!(session_id = %session_id, "cancel request for unknown session");
966        }
967
968        Ok(())
969    }
970}
971
972struct AsterAcpHandler {
973    agent: Arc<AsterAcpAgent>,
974}
975
976impl JrMessageHandler for AsterAcpHandler {
977    type Link = AgentToClient;
978
979    fn describe_chain(&self) -> impl std::fmt::Debug {
980        "aster-acp"
981    }
982
983    async fn handle_message(
984        &mut self,
985        message: MessageCx,
986        cx: JrConnectionCx<AgentToClient>,
987    ) -> Result<Handled<MessageCx>, sacp::Error> {
988        use sacp::util::MatchMessageFrom;
989        use sacp::JrRequestCx;
990
991        MatchMessageFrom::new(message, &cx)
992            .if_request(
993                |req: InitializeRequest, req_cx: JrRequestCx<InitializeResponse>| async {
994                    req_cx.respond(self.agent.on_initialize(req).await?)
995                },
996            )
997            .await
998            .if_request(
999                |_req: AuthenticateRequest, req_cx: JrRequestCx<AuthenticateResponse>| async {
1000                    req_cx.respond(AuthenticateResponse::new())
1001                },
1002            )
1003            .await
1004            .if_request(
1005                |req: NewSessionRequest, req_cx: JrRequestCx<NewSessionResponse>| async {
1006                    req_cx.respond(self.agent.on_new_session(req).await?)
1007                },
1008            )
1009            .await
1010            .if_request(
1011                |req: LoadSessionRequest, req_cx: JrRequestCx<LoadSessionResponse>| async {
1012                    req_cx.respond(self.agent.on_load_session(req, &cx).await?)
1013                },
1014            )
1015            .await
1016            .if_request(
1017                |req: PromptRequest, req_cx: JrRequestCx<PromptResponse>| async {
1018                    // Spawn the prompt processing in a task so we don't block the event loop.
1019                    // This allows permission responses to be processed while the agent is working.
1020                    let agent = self.agent.clone();
1021                    let cx_clone = cx.clone();
1022                    cx.spawn(async move {
1023                        match agent.on_prompt(req, &cx_clone).await {
1024                            Ok(response) => {
1025                                req_cx.respond(response)?;
1026                            }
1027                            Err(e) => {
1028                                req_cx.respond_with_error(e)?;
1029                            }
1030                        }
1031                        Ok(())
1032                    })?;
1033                    Ok(())
1034                },
1035            )
1036            .await
1037            .if_notification(|notif: CancelNotification| async {
1038                self.agent.on_cancel(notif).await
1039            })
1040            .await
1041            .done()
1042    }
1043}
1044
1045pub async fn run_acp_agent(builtins: Vec<String>) -> Result<()> {
1046    info!("listening on stdio");
1047
1048    let outgoing = tokio::io::stdout().compat_write();
1049    let incoming = tokio::io::stdin().compat();
1050
1051    let agent = Arc::new(AsterAcpAgent::new(builtins).await?);
1052    let handler = AsterAcpHandler { agent };
1053
1054    AgentToClient::builder()
1055        .name("aster-acp")
1056        .with_handler(handler)
1057        .serve(ByteStreams::new(outgoing, incoming))
1058        .await?;
1059
1060    Ok(())
1061}
1062
#[cfg(test)]
mod tests {
    // The helpers under test (`mcp_server_to_extension_config`, `read_resource_link`,
    // `format_tool_name`, `outcome_to_confirmation`) and the types they use (`ExtensionConfig`,
    // `Envs`, `HashMap`, permission types, ...) all come in through `super::*`.
    use super::*;
    use sacp::schema::{
        EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio,
        ResourceLink, SelectedPermissionOutcome,
    };
    use std::io::Write;
    use tempfile::NamedTempFile;
    use test_case::test_case;

    #[test_case(
        McpServer::Stdio(
            McpServerStdio::new("github", "/path/to/github-mcp-server")
                .args(vec!["stdio".into()])
                .env(vec![EnvVariable::new(
                    "GITHUB_PERSONAL_ACCESS_TOKEN",
                    "ghp_xxxxxxxxxxxx"
                )])
        ),
        Ok(ExtensionConfig::Stdio {
            name: "github".into(),
            description: String::new(),
            cmd: "/path/to/github-mcp-server".into(),
            args: vec!["stdio".into()],
            envs: Envs::new(
                [(
                    "GITHUB_PERSONAL_ACCESS_TOKEN".into(),
                    "ghp_xxxxxxxxxxxx".into()
                )]
                .into()
            ),
            env_keys: vec![],
            timeout: None,
            bundled: Some(false),
            available_tools: vec![],
        })
    )]
    #[test_case(
        McpServer::Http(
            McpServerHttp::new("github", "https://api.githubcopilot.com/mcp/")
                .headers(vec![HttpHeader::new("Authorization", "Bearer ghp_xxxxxxxxxxxx")])
        ),
        Ok(ExtensionConfig::StreamableHttp {
            name: "github".into(),
            description: String::new(),
            uri: "https://api.githubcopilot.com/mcp/".into(),
            envs: Envs::default(),
            env_keys: vec![],
            headers: HashMap::from([(
                "Authorization".into(),
                "Bearer ghp_xxxxxxxxxxxx".into()
            )]),
            timeout: None,
            bundled: Some(false),
            available_tools: vec![],
        })
    )]
    #[test_case(
        McpServer::Sse(McpServerSse::new("test-sse", "https://agent-fin.biodnd.com/sse")),
        Err("SSE is unsupported, migrate to streamable_http".to_string())
    )]
    fn test_mcp_server_to_extension_config(
        input: McpServer,
        expected: Result<ExtensionConfig, String>,
    ) {
        assert_eq!(mcp_server_to_extension_config(input), expected);
    }

    /// Writes `content` to a fresh temp file and returns a `file://` [`ResourceLink`] to it.
    ///
    /// The `NamedTempFile` is returned alongside the link. The file is removed when that
    /// handle is dropped, so callers must keep it alive while the link is read.
    fn new_resource_link(content: &str) -> anyhow::Result<(ResourceLink, NamedTempFile)> {
        let mut file = NamedTempFile::new()?;
        file.write_all(content.as_bytes())?;

        let name = file
            .path()
            .file_name()
            .unwrap()
            .to_string_lossy()
            .to_string();
        let uri = format!("file://{}", file.path().to_str().unwrap());
        let link = ResourceLink::new(name, uri);
        Ok((link, file))
    }

    /// A `file://` resource link is inlined as a heading with the file path followed by the
    /// file contents in a fenced code block.
    #[test]
    fn test_read_resource_link_file_scheme() {
        let (link, file) = new_resource_link("print(\"hello, world\")").unwrap();

        let result = read_resource_link(link).unwrap();
        let expected = format!(
            "\n\n# {}\n```\nprint(\"hello, world\")\n```",
            file.path().to_str().unwrap(),
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_format_tool_name_with_extension() {
        assert_eq!(
            format_tool_name("developer__text_editor"),
            "Developer: Text Editor"
        );
        assert_eq!(
            format_tool_name("platform__manage_extensions"),
            "Platform: Manage Extensions"
        );
        assert_eq!(format_tool_name("todo__write"), "Todo: Write");
    }

    #[test]
    fn test_format_tool_name_without_extension() {
        assert_eq!(format_tool_name("simple_tool"), "Simple Tool");
        assert_eq!(format_tool_name("another_name"), "Another Name");
        assert_eq!(format_tool_name("single"), "Single");
    }

    #[test]
    fn test_format_tool_name_edge_cases() {
        assert_eq!(format_tool_name(""), "");
        assert_eq!(format_tool_name("__"), ": ");
        assert_eq!(format_tool_name("extension__"), "Extension: ");
        assert_eq!(format_tool_name("__tool"), ": Tool");
    }

    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_once")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AllowOnce };
        "allow_once_maps_to_allow_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_always")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AlwaysAllow };
        "allow_always_maps_to_always_allow"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_once")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce };
        "reject_once_maps_to_deny_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_always")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce };
        "reject_always_maps_to_deny_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("unknown")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::Cancel };
        "unknown_option_maps_to_cancel"
    )]
    #[test_case(
        RequestPermissionOutcome::Cancelled,
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::Cancel };
        "cancelled_maps_to_cancel"
    )]
    fn test_outcome_to_confirmation(
        input: RequestPermissionOutcome,
        expected: PermissionConfirmation,
    ) {
        assert_eq!(outcome_to_confirmation(&input), expected);
    }
}