1#![allow(clippy::too_many_lines)]
23#![allow(clippy::significant_drop_tightening)]
24
25use crate::agent::{AbortHandle, AbortSignal, AgentEvent, AgentSession};
26use crate::agent_cx::AgentCx;
27use crate::auth::AuthStorage;
28use crate::compaction::ResolvedCompactionSettings;
29use crate::config::Config;
30use crate::error::{Error, Result};
31use crate::model::{AssistantMessage, AssistantMessageEvent, ContentBlock};
32use crate::models::ModelEntry;
33use crate::provider::StreamOptions;
34use crate::provider_metadata::provider_ids_match;
35use crate::providers;
36use crate::session::Session;
37use crate::tools::ToolRegistry;
38use asupersync::runtime::RuntimeHandle;
39use asupersync::sync::Mutex;
40use serde::{Deserialize, Serialize};
41use serde_json::{Value, json};
42use std::collections::HashMap;
43use std::io::{self, BufRead, Write};
44use std::path::PathBuf;
45use std::sync::Arc;
46use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
47
/// An incoming JSON-RPC 2.0 message (request or notification) read from the client.
#[derive(Debug, Clone, Deserialize)]
struct JsonRpcRequest {
    /// Protocol version; the dispatcher rejects anything other than `"2.0"`.
    jsonrpc: String,
    /// Request id; `None` when the message carries no id.
    id: Option<Value>,
    /// Method name the dispatcher matches on.
    method: String,
    /// Method parameters; `Value::Null` when the field is omitted.
    #[serde(default)]
    params: Value,
}
61
/// An outgoing JSON-RPC 2.0 response.
///
/// `json_rpc_ok` sets only `result` and `json_rpc_error` sets only `error`; whichever is
/// `None` is omitted from the serialized output.
#[derive(Debug, Clone, Serialize)]
struct JsonRpcResponse {
    /// Protocol version; set to `"2.0"` by the constructors in this module.
    jsonrpc: String,
    /// Id of the request being answered (`null` when it is unknown).
    id: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
}
72
/// An outgoing JSON-RPC 2.0 notification (no `id`), used for prompt progress and completion.
#[derive(Debug, Clone, Serialize)]
struct JsonRpcNotification {
    /// Protocol version; set to `"2.0"` by `json_rpc_notification`.
    jsonrpc: String,
    /// Notification name, e.g. `prompt/progress` or `prompt/end`.
    method: String,
    /// Notification payload.
    params: Value,
}
80
/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Clone, Serialize)]
struct JsonRpcError {
    /// One of the error-code constants defined in this module.
    code: i64,
    /// Human-readable description of the failure.
    message: String,
    /// Optional structured details; omitted when `None` (`json_rpc_error` always passes `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<Value>,
}
89
// Standard JSON-RPC 2.0 error codes.

/// The input line was not valid JSON.
const PARSE_ERROR: i64 = -32700;
/// The message was not a valid JSON-RPC request (or the server is not ready for it).
const INVALID_REQUEST: i64 = -32600;
/// The requested method is not implemented by this server.
const METHOD_NOT_FOUND: i64 = -32601;
/// A required parameter was missing or invalid.
const INVALID_PARAMS: i64 = -32602;
/// The request was valid but handling it failed on the server side.
const INTERNAL_ERROR: i64 = -32603;

// Server-defined error codes (JSON-RPC reserves -32000..=-32099 for implementations).

/// No session is registered under the given `sessionId`.
const SESSION_NOT_FOUND: i64 = -32001;
/// The session already has a prompt in flight.
const PROMPT_IN_PROGRESS: i64 = -32002;
/// No in-flight prompt matches the given `promptId`.
const PROMPT_NOT_FOUND: i64 = -32003;
101
102fn json_rpc_ok(id: Value, result: Value) -> String {
103 serde_json::to_string(&JsonRpcResponse {
104 jsonrpc: "2.0".to_string(),
105 id,
106 result: Some(result),
107 error: None,
108 })
109 .expect("serialize json-rpc response")
110}
111
112fn json_rpc_error(id: Value, code: i64, message: impl Into<String>) -> String {
113 serde_json::to_string(&JsonRpcResponse {
114 jsonrpc: "2.0".to_string(),
115 id,
116 result: None,
117 error: Some(JsonRpcError {
118 code,
119 message: message.into(),
120 data: None,
121 }),
122 })
123 .expect("serialize json-rpc error")
124}
125
126fn json_rpc_notification(method: &str, params: Value) -> String {
127 serde_json::to_string(&JsonRpcNotification {
128 jsonrpc: "2.0".to_string(),
129 method: method.to_string(),
130 params,
131 })
132 .expect("serialize json-rpc notification")
133}
134
/// Registry of live ACP sessions keyed by session id.
///
/// Each entry has its own lock, so callers can clone an entry's `Arc` out of the map and
/// release the registry lock before working with the session.
type AcpSessionsMap = Arc<Mutex<HashMap<String, Arc<Mutex<AcpSessionState>>>>>;
140
/// A model advertised to the client in `session/new` and `session/load` responses.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct AcpModel {
    /// Model identifier.
    id: String,
    /// Display name.
    name: String,
    /// Provider id; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    provider: Option<String>,
}
153
/// An interaction mode advertised to the client in the `session/new` response.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct AcpMode {
    /// Stable machine-readable identifier (e.g. `"agent"`).
    slug: String,
    /// Display name.
    name: String,
    /// One-line description shown to the user.
    description: String,
}
162
/// One item of assistant output in a `prompt/end` notification.
///
/// Serialized as an internally tagged object, e.g. `{"type": "text", "text": "..."}`; the
/// explicit per-variant `rename`s determine the `type` values.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum AcpContentItem {
    /// Plain assistant text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Model reasoning ("thinking") text.
    #[serde(rename = "thinking")]
    Thinking { text: String },
    /// A tool invocation requested by the model, with its JSON arguments.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
}
178
/// Per-session server state stored in the sessions registry.
struct AcpSessionState {
    /// The agent driving this session. `run_prompt` takes it out while a prompt runs and
    /// stores it back when the prompt finishes, so it is `None` during a prompt.
    agent_session: Option<AgentSession>,
    /// Working directory the session was created with; `validate_file_path` confines
    /// `read_text_file` / `write_text_file` to it.
    cwd: PathBuf,
    /// The session's id (the same string used as its registry key).
    session_id: String,
}
194
/// Everything the ACP server needs to create sessions and run prompts.
#[derive(Clone)]
pub struct AcpOptions {
    /// Application configuration (default provider/model, thinking level, compaction, ...).
    pub config: Config,
    /// Models offered to clients; each session's model is selected from this list.
    pub available_models: Vec<ModelEntry>,
    /// Credential store used to resolve provider API keys.
    pub auth: AuthStorage,
    /// Runtime on which prompt tasks are spawned and agent sessions run.
    pub runtime_handle: RuntimeHandle,
}
207
208pub async fn run_stdio(options: AcpOptions) -> Result<()> {
213 let (in_tx, in_rx) = asupersync::channel::mpsc::channel::<String>(256);
214 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
215
216 std::thread::spawn(move || {
218 let stdin = io::stdin();
219 let mut reader = io::BufReader::new(stdin.lock());
220 let mut line = String::new();
221 loop {
222 line.clear();
223 match reader.read_line(&mut line) {
224 Ok(0) | Err(_) => break,
225 Ok(_) => {
226 let trimmed = line.trim().to_string();
227 if trimmed.is_empty() {
228 continue;
229 }
230 let mut to_send = trimmed;
232 loop {
233 match in_tx.try_send(to_send) {
234 Ok(()) => break,
235 Err(asupersync::channel::mpsc::SendError::Full(unsent)) => {
236 to_send = unsent;
237 std::thread::sleep(std::time::Duration::from_millis(10));
238 }
239 Err(_) => return,
240 }
241 }
242 }
243 }
244 }
245 });
246
247 std::thread::spawn(move || {
249 let stdout = io::stdout();
250 let mut writer = io::BufWriter::new(stdout.lock());
251 for line in out_rx {
252 if writer.write_all(line.as_bytes()).is_err() {
253 break;
254 }
255 if writer.write_all(b"\n").is_err() {
256 break;
257 }
258 if writer.flush().is_err() {
259 break;
260 }
261 }
262 });
263
264 run(options, in_rx, out_tx).await
265}
266
267async fn run(
269 options: AcpOptions,
270 mut in_rx: asupersync::channel::mpsc::Receiver<String>,
271 out_tx: std::sync::mpsc::SyncSender<String>,
272) -> Result<()> {
273 let cx = AgentCx::for_current_or_request();
274 let sessions: AcpSessionsMap = Arc::new(Mutex::new(HashMap::new()));
275 let prompt_counter = Arc::new(AtomicU64::new(0));
276 let active_prompts: Arc<Mutex<HashMap<String, AbortHandle>>> =
277 Arc::new(Mutex::new(HashMap::new()));
278 let initialized = Arc::new(AtomicBool::new(false));
279
280 while let Ok(line) = in_rx.recv(&cx).await {
281 let request: JsonRpcRequest = match serde_json::from_str(&line) {
283 Ok(req) => req,
284 Err(err) => {
285 let _ = out_tx.send(json_rpc_error(
286 Value::Null,
287 PARSE_ERROR,
288 format!("Parse error: {err}"),
289 ));
290 continue;
291 }
292 };
293
294 if request.jsonrpc != "2.0" {
296 if let Some(ref id) = request.id {
297 let _ = out_tx.send(json_rpc_error(
298 id.clone(),
299 INVALID_REQUEST,
300 "Expected jsonrpc version 2.0",
301 ));
302 }
303 continue;
304 }
305
306 let id = request.id.clone().unwrap_or(Value::Null);
307
308 match request.method.as_str() {
309 "initialize" => {
310 let result = handle_initialize();
311 initialized.store(true, Ordering::SeqCst);
312 let _ = out_tx.send(json_rpc_ok(id, result));
313 }
314
315 "initialized" => {}
318
319 "shutdown" => {
321 let _ = out_tx.send(json_rpc_ok(id, json!(null)));
322 }
323
324 "exit" => {
326 break;
327 }
328
329 "session/new" => {
330 if !initialized.load(Ordering::SeqCst) {
331 let _ = out_tx.send(json_rpc_error(
332 id,
333 INVALID_REQUEST,
334 "Server not initialized. Call 'initialize' first.",
335 ));
336 continue;
337 }
338
339 match handle_session_new(&request.params, &options, &cx).await {
340 Ok((session_id, state)) => {
341 let models: Vec<AcpModel> = options
342 .available_models
343 .iter()
344 .map(|entry| AcpModel {
345 id: entry.model.id.clone(),
346 name: entry.model.name.clone(),
347 provider: Some(entry.model.provider.clone()),
348 })
349 .collect();
350
351 let modes = vec![
352 AcpMode {
353 slug: "agent".to_string(),
354 name: "Agent".to_string(),
355 description: "Full autonomous coding agent with tool access"
356 .to_string(),
357 },
358 AcpMode {
359 slug: "chat".to_string(),
360 name: "Chat".to_string(),
361 description: "Conversational mode without tool execution"
362 .to_string(),
363 },
364 ];
365
366 let state_arc = Arc::new(Mutex::new(state));
367 if let Ok(mut guard) = sessions.lock(&cx).await {
368 guard.insert(session_id.clone(), state_arc);
369 }
370
371 let _ = out_tx.send(json_rpc_ok(
372 id,
373 json!({
374 "sessionId": session_id,
375 "models": models,
376 "modes": modes,
377 }),
378 ));
379 }
380 Err(err) => {
381 let _ = out_tx.send(json_rpc_error(
382 id,
383 INTERNAL_ERROR,
384 format!("Failed to create session: {err}"),
385 ));
386 }
387 }
388 }
389
390 "prompt" => {
391 if !initialized.load(Ordering::SeqCst) {
392 let _ = out_tx.send(json_rpc_error(
393 id,
394 INVALID_REQUEST,
395 "Server not initialized",
396 ));
397 continue;
398 }
399
400 let session_id = request
401 .params
402 .get("sessionId")
403 .and_then(Value::as_str)
404 .map(String::from);
405 let message_text = request
406 .params
407 .get("message")
408 .and_then(Value::as_str)
409 .map(String::from);
410
411 let Some(session_id) = session_id else {
412 let _ = out_tx.send(json_rpc_error(
413 id,
414 INVALID_PARAMS,
415 "Missing required parameter: sessionId",
416 ));
417 continue;
418 };
419
420 let Some(message_text) = message_text else {
421 let _ = out_tx.send(json_rpc_error(
422 id,
423 INVALID_PARAMS,
424 "Missing required parameter: message",
425 ));
426 continue;
427 };
428
429 let session_state = {
430 sessions
431 .lock(&cx)
432 .await
433 .map_or_else(|_| None, |guard| guard.get(&session_id).cloned())
434 };
435
436 let Some(session_state) = session_state else {
437 let _ = out_tx.send(json_rpc_error(
438 id,
439 SESSION_NOT_FOUND,
440 format!("Session not found: {session_id}"),
441 ));
442 continue;
443 };
444
445 {
447 let has_active = active_prompts.lock(&cx).await.is_ok_and(|guard| {
448 guard
449 .keys()
450 .any(|k| k.starts_with(&format!("{session_id}:")))
451 });
452 if has_active {
453 let _ = out_tx.send(json_rpc_error(
454 id,
455 PROMPT_IN_PROGRESS,
456 format!("Session {session_id} already has an active prompt"),
457 ));
458 continue;
459 }
460 }
461
462 let prompt_seq = prompt_counter.fetch_add(1, Ordering::SeqCst);
464 let prompt_id = format!("{session_id}:prompt-{prompt_seq}");
465
466 let (abort_handle, abort_signal) = AbortHandle::new();
468 if let Ok(mut guard) = active_prompts.lock(&cx).await {
469 guard.insert(prompt_id.clone(), abort_handle);
470 }
471
472 let _ = out_tx.send(json_rpc_ok(
474 id,
475 json!({
476 "promptId": prompt_id,
477 }),
478 ));
479
480 let out_tx_prompt = out_tx.clone();
482 let active_prompts_cleanup = Arc::clone(&active_prompts);
483 let prompt_id_cleanup = prompt_id.clone();
484 let prompt_cx = cx.clone();
485 let prompt_session_id = session_id.clone();
486
487 options.runtime_handle.spawn(async move {
488 run_prompt(
489 session_state,
490 message_text,
491 abort_signal,
492 out_tx_prompt,
493 prompt_id.clone(),
494 prompt_session_id,
495 prompt_cx.clone(),
496 )
497 .await;
498
499 if let Ok(mut guard) = active_prompts_cleanup.lock(&prompt_cx).await {
501 guard.remove(&prompt_id_cleanup);
502 }
503 });
504 }
505
506 "cancel" => {
507 let prompt_id = request
508 .params
509 .get("promptId")
510 .and_then(Value::as_str)
511 .map(String::from);
512
513 let Some(prompt_id) = prompt_id else {
514 let _ = out_tx.send(json_rpc_error(
515 id,
516 INVALID_PARAMS,
517 "Missing required parameter: promptId",
518 ));
519 continue;
520 };
521
522 let aborted = active_prompts.lock(&cx).await.is_ok_and(|guard| {
523 guard.get(&prompt_id).is_some_and(|handle| {
524 handle.abort();
525 true
526 })
527 });
528
529 if aborted {
530 let _ = out_tx.send(json_rpc_ok(id, json!({ "cancelled": true })));
531 } else {
532 let _ = out_tx.send(json_rpc_error(
533 id,
534 PROMPT_NOT_FOUND,
535 format!("No active prompt with id: {prompt_id}"),
536 ));
537 }
538 }
539
540 "session/list" => {
541 let session_list: Vec<Value> = sessions.lock(&cx).await.map_or_else(
542 |_| Vec::new(),
543 |guard| {
544 guard
545 .keys()
546 .map(|sid| json!({ "sessionId": sid }))
547 .collect()
548 },
549 );
550
551 let _ = out_tx.send(json_rpc_ok(id, json!({ "sessions": session_list })));
552 }
553
554 "session/load" => {
555 let session_id = request
556 .params
557 .get("sessionId")
558 .and_then(Value::as_str)
559 .map(String::from);
560
561 let Some(session_id) = session_id else {
562 let _ = out_tx.send(json_rpc_error(
563 id,
564 INVALID_PARAMS,
565 "Missing required parameter: sessionId",
566 ));
567 continue;
568 };
569
570 let exists = sessions
571 .lock(&cx)
572 .await
573 .is_ok_and(|guard| guard.contains_key(&session_id));
574
575 if exists {
576 let models: Vec<AcpModel> = options
577 .available_models
578 .iter()
579 .map(|entry| AcpModel {
580 id: entry.model.id.clone(),
581 name: entry.model.name.clone(),
582 provider: Some(entry.model.provider.clone()),
583 })
584 .collect();
585
586 let _ = out_tx.send(json_rpc_ok(
587 id,
588 json!({
589 "sessionId": session_id,
590 "models": models,
591 }),
592 ));
593 } else {
594 let _ = out_tx.send(json_rpc_error(
595 id,
596 SESSION_NOT_FOUND,
597 format!("Session not found: {session_id}"),
598 ));
599 }
600 }
601
602 "session/resume" => {
603 let session_id = request
604 .params
605 .get("sessionId")
606 .and_then(Value::as_str)
607 .map(String::from);
608
609 let Some(session_id) = session_id else {
610 let _ = out_tx.send(json_rpc_error(
611 id,
612 INVALID_PARAMS,
613 "Missing required parameter: sessionId",
614 ));
615 continue;
616 };
617
618 let exists = sessions
619 .lock(&cx)
620 .await
621 .is_ok_and(|guard| guard.contains_key(&session_id));
622
623 if exists {
624 let _ = out_tx.send(json_rpc_ok(
625 id,
626 json!({
627 "sessionId": session_id,
628 "resumed": true,
629 }),
630 ));
631 } else {
632 let _ = out_tx.send(json_rpc_error(
633 id,
634 SESSION_NOT_FOUND,
635 format!("Session not found: {session_id}"),
636 ));
637 }
638 }
639
640 "read_text_file" => {
643 let path_str = match request.params.get("path").and_then(Value::as_str) {
644 Some(p) if !p.is_empty() => p,
645 _ => {
646 let _ = out_tx.send(json_rpc_error(
647 id,
648 INVALID_PARAMS,
649 "Missing or empty required parameter: path",
650 ));
651 continue;
652 }
653 };
654 let session_id = request.params.get("sessionId").and_then(Value::as_str);
655
656 if let Err(msg) = validate_file_path(path_str, session_id, &sessions, &cx).await {
657 let _ = out_tx.send(json_rpc_error(id, INVALID_PARAMS, msg));
658 continue;
659 }
660
661 let max_bytes = 10 * 1024 * 1024; match asupersync::fs::metadata(path_str).await {
663 Ok(meta) if meta.len() > max_bytes => {
664 let _ = out_tx.send(json_rpc_error(
665 id,
666 INTERNAL_ERROR,
667 format!(
668 "File too large ({} bytes). Maximum allowed via ACP is {} bytes.",
669 meta.len(),
670 max_bytes
671 ),
672 ));
673 continue;
674 }
675 _ => {}
676 }
677
678 match asupersync::fs::read(path_str).await {
679 Ok(bytes) => {
680 let contents = String::from_utf8_lossy(&bytes).into_owned();
681 let _ = out_tx.send(json_rpc_ok(id, json!({ "contents": contents })));
682 }
683 Err(err) => {
684 let _ = out_tx.send(json_rpc_error(
685 id,
686 INTERNAL_ERROR,
687 format!("Failed to read file: {err}"),
688 ));
689 }
690 }
691 }
692
693 "write_text_file" => {
694 let path_str = match request.params.get("path").and_then(Value::as_str) {
695 Some(p) if !p.is_empty() => p,
696 _ => {
697 let _ = out_tx.send(json_rpc_error(
698 id,
699 INVALID_PARAMS,
700 "Missing or empty required parameter: path",
701 ));
702 continue;
703 }
704 };
705 let Some(contents) = request.params.get("contents").and_then(Value::as_str) else {
706 let _ = out_tx.send(json_rpc_error(
707 id,
708 INVALID_PARAMS,
709 "Missing required parameter: contents",
710 ));
711 continue;
712 };
713 let session_id = request.params.get("sessionId").and_then(Value::as_str);
714
715 if let Err(msg) = validate_file_path(path_str, session_id, &sessions, &cx).await {
716 let _ = out_tx.send(json_rpc_error(id, INVALID_PARAMS, msg));
717 continue;
718 }
719
720 match asupersync::fs::write(path_str, contents.as_bytes()).await {
721 Ok(()) => {
722 let _ = out_tx.send(json_rpc_ok(id, json!({ "success": true })));
723 }
724 Err(err) => {
725 let _ = out_tx.send(json_rpc_error(
726 id,
727 INTERNAL_ERROR,
728 format!("Failed to write file: {err}"),
729 ));
730 }
731 }
732 }
733
734 _ => {
736 let _ = out_tx.send(json_rpc_error(
737 id,
738 METHOD_NOT_FOUND,
739 format!("Method not found: {}", request.method),
740 ));
741 }
742 }
743 }
744
745 Ok(())
746}
747
748async fn validate_file_path(
757 path_str: &str,
758 session_id: Option<&str>,
759 sessions: &AcpSessionsMap,
760 cx: &AgentCx,
761) -> std::result::Result<(), String> {
762 let resolved = if let Ok(p) = std::path::Path::new(path_str).canonicalize() {
763 p
764 } else {
765 let parent = std::path::Path::new(path_str).parent();
767 match parent.and_then(|p| p.canonicalize().ok()) {
768 Some(p) => p.join(
769 std::path::Path::new(path_str)
770 .file_name()
771 .unwrap_or_default(),
772 ),
773 None => {
774 return Err(format!(
775 "Path does not exist and parent is invalid: {path_str}"
776 ));
777 }
778 }
779 };
780
781 let guard = sessions
782 .lock(cx)
783 .await
784 .map_err(|e| format!("Lock failed: {e}"))?;
785
786 if guard.is_empty() {
787 return Err("No active sessions — cannot validate file path".to_string());
788 }
789
790 let allowed_cwds: Vec<PathBuf> = if let Some(sid) = session_id {
791 match guard.get(sid) {
792 Some(state) => {
793 if let Ok(s) = state.lock(cx).await {
794 vec![s.cwd.clone()]
795 } else {
796 return Err("Session lock failed".to_string());
797 }
798 }
799 None => return Err(format!("Session not found: {sid}")),
800 }
801 } else {
802 let mut cwds = Vec::new();
803 for state in guard.values() {
804 if let Ok(s) = state.lock(cx).await {
805 cwds.push(s.cwd.clone());
806 }
807 }
808 cwds
809 };
810
811 for cwd in &allowed_cwds {
813 if let Ok(canonical_cwd) = cwd.canonicalize() {
814 if resolved.starts_with(&canonical_cwd) {
815 return Ok(());
816 }
817 }
818 if resolved.starts_with(cwd) {
820 return Ok(());
821 }
822 }
823
824 Err(format!(
825 "Path '{path_str}' is outside all session working directories",
826 ))
827}
828
829fn handle_initialize() -> Value {
834 let version = env!("CARGO_PKG_VERSION");
835 json!({
836 "protocolVersion": "2025-01-01",
837 "serverInfo": {
838 "name": "pi-agent",
839 "version": version,
840 },
841 "capabilities": {
842 "streaming": true,
843 "toolApproval": false,
844 },
845 })
846}
847
848fn select_acp_model_entry(config: &Config, available_models: &[ModelEntry]) -> Option<ModelEntry> {
849 if let (Some(default_provider), Some(default_model)) = (
850 config.default_provider.as_deref(),
851 config.default_model.as_deref(),
852 ) {
853 if let Some(entry) = available_models.iter().find(|entry| {
854 provider_ids_match(&entry.model.provider, default_provider)
855 && entry.model.id.eq_ignore_ascii_case(default_model)
856 }) {
857 return Some(entry.clone());
858 }
859 }
860
861 if let Some(default_provider) = config.default_provider.as_deref() {
862 if let Some(entry) = available_models
863 .iter()
864 .find(|entry| provider_ids_match(&entry.model.provider, default_provider))
865 {
866 return Some(entry.clone());
867 }
868 }
869
870 if let Some(default_model) = config.default_model.as_deref() {
871 if let Some(entry) = available_models
872 .iter()
873 .find(|entry| entry.model.id.eq_ignore_ascii_case(default_model))
874 {
875 return Some(entry.clone());
876 }
877 }
878
879 available_models.first().cloned()
880}
881
882fn resolve_acp_thinking_level(
883 config: &Config,
884 model_entry: &ModelEntry,
885) -> crate::model::ThinkingLevel {
886 let requested = config
887 .default_thinking_level
888 .as_deref()
889 .and_then(|value| value.parse().ok())
890 .unwrap_or(crate::model::ThinkingLevel::XHigh);
891 model_entry.clamp_thinking_level(requested)
892}
893
894fn build_acp_system_prompt(cwd: &std::path::Path, enabled_tools: &[&str]) -> String {
896 use std::fmt::Write as _;
897
898 let tool_descriptions = [
899 ("read", "Read file contents"),
900 ("bash", "Execute bash commands"),
901 ("edit", "Make surgical edits to files"),
902 ("write", "Write file contents"),
903 ("grep", "Search file contents with regex"),
904 ("find", "Find files by name pattern"),
905 ("ls", "List directory contents"),
906 ];
907
908 let mut prompt = String::from(
909 "You are a helpful AI coding assistant integrated into the user's editor via ACP (Agent Client Protocol). \
910 You have access to the following tools:\n\n",
911 );
912
913 for (name, description) in &tool_descriptions {
914 if enabled_tools.contains(name) {
915 let _ = writeln!(prompt, "- **{name}**: {description}");
916 }
917 }
918
919 prompt.push_str(
920 "\nUse these tools to help the user with coding tasks. \
921 Be concise and precise. When making file changes, explain what you're doing.\n",
922 );
923
924 for filename in &["pi.md", "AGENTS.md", ".pi"] {
926 let path = cwd.join(filename);
927 if path.is_file() {
928 if let Ok(content) = std::fs::read_to_string(&path) {
929 let _ = write!(prompt, "\n## {filename}\n\n{content}\n\n");
930 }
931 }
932 }
933
934 let date_time = chrono::Utc::now()
935 .format("%Y-%m-%d %H:%M:%S UTC")
936 .to_string();
937 let _ = write!(prompt, "\nCurrent date and time: {date_time}");
938 let _ = write!(prompt, "\nCurrent working directory: {}", cwd.display());
939
940 prompt
941}
942
943async fn handle_session_new(
944 params: &Value,
945 options: &AcpOptions,
946 _cx: &AgentCx,
947) -> Result<(String, AcpSessionState)> {
948 let cwd = params.get("cwd").and_then(Value::as_str).map_or_else(
949 || std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
950 PathBuf::from,
951 );
952
953 let mut session = Session::in_memory();
955 session.header.cwd = cwd.display().to_string();
956 let session_id = session.header.id.clone();
957
958 let enabled_tools: Vec<&str> = vec!["read", "bash", "edit", "write", "grep", "find", "ls"];
960 let tools = ToolRegistry::new(&enabled_tools, &cwd, Some(&options.config));
961
962 let model_entry = select_acp_model_entry(&options.config, &options.available_models)
965 .ok_or_else(|| Error::provider("acp", "No models available"))?;
966
967 let provider = providers::create_provider(&model_entry, None)
968 .map_err(|e| Error::provider("acp", e.to_string()))?;
969
970 let system_prompt = build_acp_system_prompt(&cwd, &enabled_tools);
972
973 let api_key = options
975 .auth
976 .resolve_api_key(&model_entry.model.provider, None)
977 .or_else(|| model_entry.api_key.clone())
978 .and_then(|k| {
979 let trimmed = k.trim();
980 (!trimmed.is_empty()).then(|| trimmed.to_string())
981 });
982
983 let stream_options = StreamOptions {
984 api_key,
985 thinking_level: Some(resolve_acp_thinking_level(&options.config, &model_entry)),
986 headers: model_entry.headers.clone(),
987 ..StreamOptions::default()
988 };
989
990 let agent_config = crate::agent::AgentConfig {
991 system_prompt: Some(system_prompt),
992 max_tool_iterations: 50,
993 stream_options,
994 block_images: options.config.image_block_images(),
995 fail_closed_hooks: options.config.fail_closed_hooks(),
996 };
997
998 let agent = crate::agent::Agent::new(provider, tools, agent_config);
999 let session_arc = Arc::new(Mutex::new(session));
1000 let compaction_settings = ResolvedCompactionSettings {
1001 enabled: options.config.compaction_enabled(),
1002 reserve_tokens: options.config.compaction_reserve_tokens(),
1003 keep_recent_tokens: options.config.compaction_keep_recent_tokens(),
1004 context_window_tokens: if model_entry.model.context_window == 0 {
1005 ResolvedCompactionSettings::default().context_window_tokens
1006 } else {
1007 model_entry.model.context_window
1008 },
1009 };
1010
1011 let agent_session = AgentSession::new(agent, session_arc, false, compaction_settings)
1012 .with_runtime_handle(options.runtime_handle.clone());
1013
1014 Ok((
1015 session_id.clone(),
1016 AcpSessionState {
1017 agent_session: Some(agent_session),
1018 cwd,
1019 session_id,
1020 },
1021 ))
1022}
1023
1024async fn run_prompt(
1026 session_state: Arc<Mutex<AcpSessionState>>,
1027 message: String,
1028 abort_signal: AbortSignal,
1029 out_tx: std::sync::mpsc::SyncSender<String>,
1030 prompt_id: String,
1031 session_id: String,
1032 cx: AgentCx,
1033) {
1034 let out_tx_events = out_tx.clone();
1035 let prompt_id_events = prompt_id.clone();
1036 let session_id_events = session_id.clone();
1037
1038 let event_handler = build_acp_event_handler(out_tx_events, prompt_id_events, session_id_events);
1040
1041 let mut agent_session = {
1047 let mut guard = match session_state.lock(&cx).await {
1048 Ok(guard) => guard,
1049 Err(err) => {
1050 let _ = out_tx.send(json_rpc_notification(
1051 "prompt/end",
1052 json!({
1053 "promptId": prompt_id,
1054 "sessionId": session_id,
1055 "error": format!("Session lock failed: {err}"),
1056 }),
1057 ));
1058 return;
1059 }
1060 };
1061 let Some(agent) = guard.agent_session.take() else {
1062 let _ = out_tx.send(json_rpc_notification(
1063 "prompt/end",
1064 json!({
1065 "promptId": prompt_id,
1066 "sessionId": session_id,
1067 "error": "Session is busy (agent_session unavailable)",
1068 }),
1069 ));
1070 return;
1071 };
1072 agent
1073 };
1074
1075 let result = agent_session
1076 .run_text_with_abort(message, Some(abort_signal), event_handler)
1077 .await;
1078
1079 if let Ok(mut guard) = session_state.lock(&cx).await {
1081 guard.agent_session = Some(agent_session);
1082 }
1083
1084 match result {
1086 Ok(ref msg) => {
1087 let content = assistant_message_to_acp_content(msg);
1088 let _ = out_tx.send(json_rpc_notification(
1089 "prompt/end",
1090 json!({
1091 "promptId": prompt_id,
1092 "sessionId": session_id,
1093 "content": content,
1094 "stopReason": serde_json::to_value(msg.stop_reason)
1095 .unwrap_or_else(|_| json!("unknown")),
1096 }),
1097 ));
1098 }
1099 Err(ref err) => {
1100 let _ = out_tx.send(json_rpc_notification(
1101 "prompt/end",
1102 json!({
1103 "promptId": prompt_id,
1104 "sessionId": session_id,
1105 "error": err.to_string(),
1106 }),
1107 ));
1108 }
1109 }
1110}
1111
1112fn build_acp_event_handler(
1114 out_tx: std::sync::mpsc::SyncSender<String>,
1115 prompt_id: String,
1116 session_id: String,
1117) -> impl Fn(AgentEvent) + Send + Sync + 'static {
1118 move |event: AgentEvent| {
1119 let notification = match &event {
1120 AgentEvent::MessageUpdate {
1121 assistant_message_event,
1122 ..
1123 } => match assistant_message_event {
1124 AssistantMessageEvent::TextDelta { delta, .. } => Some(json_rpc_notification(
1125 "prompt/progress",
1126 json!({
1127 "promptId": prompt_id,
1128 "sessionId": session_id,
1129 "kind": "textDelta",
1130 "content": [{
1131 "type": "text",
1132 "text": delta,
1133 }],
1134 }),
1135 )),
1136 AssistantMessageEvent::TextEnd { content, .. } => Some(json_rpc_notification(
1137 "prompt/progress",
1138 json!({
1139 "promptId": prompt_id,
1140 "sessionId": session_id,
1141 "kind": "textEnd",
1142 "content": [{
1143 "type": "text",
1144 "text": content,
1145 }],
1146 }),
1147 )),
1148 AssistantMessageEvent::ThinkingDelta { delta, .. } => Some(json_rpc_notification(
1149 "prompt/progress",
1150 json!({
1151 "promptId": prompt_id,
1152 "sessionId": session_id,
1153 "kind": "thinkingDelta",
1154 "content": [{
1155 "type": "thinking",
1156 "text": delta,
1157 }],
1158 }),
1159 )),
1160 AssistantMessageEvent::ToolCallEnd { tool_call, .. } => {
1161 Some(json_rpc_notification(
1162 "prompt/progress",
1163 json!({
1164 "promptId": prompt_id,
1165 "sessionId": session_id,
1166 "kind": "toolUse",
1167 "content": [{
1168 "type": "tool_use",
1169 "id": tool_call.id,
1170 "name": tool_call.name,
1171 "input": tool_call.arguments,
1172 }],
1173 }),
1174 ))
1175 }
1176 _ => None,
1177 },
1178
1179 AgentEvent::ToolExecutionStart {
1180 tool_call_id,
1181 tool_name,
1182 args,
1183 } => Some(json_rpc_notification(
1184 "prompt/progress",
1185 json!({
1186 "promptId": prompt_id,
1187 "sessionId": session_id,
1188 "kind": "toolExecutionStart",
1189 "toolCallId": tool_call_id,
1190 "toolName": tool_name,
1191 "args": args,
1192 }),
1193 )),
1194
1195 AgentEvent::ToolExecutionEnd {
1196 tool_call_id,
1197 tool_name,
1198 result,
1199 is_error,
1200 } => {
1201 let content_text = result
1202 .content
1203 .iter()
1204 .filter_map(|block| match block {
1205 ContentBlock::Text(t) => Some(t.text.as_str()),
1206 _ => None,
1207 })
1208 .collect::<Vec<_>>()
1209 .join("\n");
1210
1211 Some(json_rpc_notification(
1212 "prompt/progress",
1213 json!({
1214 "promptId": prompt_id,
1215 "sessionId": session_id,
1216 "kind": "toolResult",
1217 "toolName": tool_name,
1218 "content": [{
1219 "type": "tool_result",
1220 "toolUseId": tool_call_id,
1221 "content": content_text,
1222 "isError": is_error,
1223 }],
1224 }),
1225 ))
1226 }
1227
1228 AgentEvent::TurnStart { turn_index, .. } => Some(json_rpc_notification(
1229 "prompt/progress",
1230 json!({
1231 "promptId": prompt_id,
1232 "sessionId": session_id,
1233 "kind": "turnStart",
1234 "turnIndex": turn_index,
1235 }),
1236 )),
1237
1238 AgentEvent::TurnEnd { turn_index, .. } => Some(json_rpc_notification(
1239 "prompt/progress",
1240 json!({
1241 "promptId": prompt_id,
1242 "sessionId": session_id,
1243 "kind": "turnEnd",
1244 "turnIndex": turn_index,
1245 }),
1246 )),
1247
1248 AgentEvent::AgentStart { .. } => Some(json_rpc_notification(
1249 "prompt/progress",
1250 json!({
1251 "promptId": prompt_id,
1252 "sessionId": session_id,
1253 "kind": "agentStart",
1254 }),
1255 )),
1256
1257 AgentEvent::AgentEnd { error, .. } => Some(json_rpc_notification(
1258 "prompt/progress",
1259 json!({
1260 "promptId": prompt_id,
1261 "sessionId": session_id,
1262 "kind": "agentEnd",
1263 "error": error,
1264 }),
1265 )),
1266
1267 _ => None,
1269 };
1270
1271 if let Some(notif) = notification {
1272 let _ = out_tx.send(notif);
1273 }
1274 }
1275}
1276
1277fn assistant_message_to_acp_content(msg: &AssistantMessage) -> Vec<AcpContentItem> {
1279 let mut items = Vec::new();
1280 for block in &msg.content {
1281 match block {
1282 ContentBlock::Text(t) => {
1283 items.push(AcpContentItem::Text {
1284 text: t.text.clone(),
1285 });
1286 }
1287 ContentBlock::Thinking(t) => {
1288 items.push(AcpContentItem::Thinking {
1289 text: t.thinking.clone(),
1290 });
1291 }
1292 ContentBlock::ToolCall(tc) => {
1293 items.push(AcpContentItem::ToolUse {
1294 id: tc.id.clone(),
1295 name: tc.name.clone(),
1296 input: tc.arguments.clone(),
1297 });
1298 }
1299 ContentBlock::Image(_) => {
1300 }
1302 }
1303 }
1304 items
1305}
1306
#[cfg(test)]
mod tests {
    //! Unit tests for the JSON-RPC framing helpers, the `initialize` handshake,
    //! ACP model / thinking-level selection, and assistant-message conversion.

    use super::*;
    use crate::provider::{InputType, Model, ModelCost};
    use std::collections::HashMap;

    /// Builds a minimal `ModelEntry` for `provider` / `id`.
    ///
    /// The entry is reasoning-capable and text-only, with zero cost, no API
    /// key and no extra headers. Tests therefore only need to vary the
    /// provider, the id, or individual fields they override afterwards.
    fn test_model_entry(provider: &str, id: &str) -> ModelEntry {
        ModelEntry {
            model: Model {
                id: id.to_string(),
                name: id.to_string(),
                api: "openai-responses".to_string(),
                provider: provider.to_string(),
                base_url: "https://example.invalid".to_string(),
                reasoning: true,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    #[test]
    fn json_rpc_ok_response_format() {
        let response = json_rpc_ok(json!(1), json!({"key": "value"}));
        let parsed: Value = serde_json::from_str(&response).expect("valid json");
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["result"]["key"], "value");
        // Success responses must omit `error` entirely, not serialize it as null.
        assert!(parsed.get("error").is_none());
    }

    #[test]
    fn json_rpc_error_response_format() {
        let response = json_rpc_error(Value::String("test-id".into()), PARSE_ERROR, "bad json");
        let parsed: Value = serde_json::from_str(&response).expect("valid json");
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["id"], "test-id");
        // Error responses must omit `result` entirely.
        assert!(parsed.get("result").is_none());
        assert_eq!(parsed["error"]["code"], PARSE_ERROR);
        assert_eq!(parsed["error"]["message"], "bad json");
    }

    #[test]
    fn json_rpc_notification_format() {
        let notif = json_rpc_notification(
            "prompt/progress",
            json!({"promptId": "p1", "kind": "textDelta"}),
        );
        let parsed: Value = serde_json::from_str(&notif).expect("valid json");
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["method"], "prompt/progress");
        assert_eq!(parsed["params"]["promptId"], "p1");
        // Notifications carry no `id`.
        assert!(parsed.get("id").is_none());
    }

    #[test]
    fn handle_initialize_returns_correct_shape() {
        let result = handle_initialize();

        assert_eq!(result["protocolVersion"], "2025-01-01");
        assert_eq!(result["serverInfo"]["name"], "pi-agent");
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
        // Comparing against a bool matches only `Value::Bool`, and on failure it
        // reports the actual value instead of panicking inside an `unwrap`.
        assert_eq!(result["capabilities"]["streaming"], true);
        assert_eq!(result["capabilities"]["toolApproval"], false);
    }

    #[test]
    fn select_acp_model_entry_prefers_exact_configured_model() {
        let config = Config {
            default_provider: Some("anthropic".to_string()),
            default_model: Some("claude-opus-4-5".to_string()),
            ..Config::default()
        };
        let available = vec![
            test_model_entry("openai", "gpt-5.2"),
            test_model_entry("anthropic", "claude-opus-4-5"),
        ];

        let selected = select_acp_model_entry(&config, &available).expect("selected model");

        assert_eq!(selected.model.provider, "anthropic");
        assert_eq!(selected.model.id, "claude-opus-4-5");
    }

    #[test]
    fn select_acp_model_entry_prefers_default_provider_when_model_is_unset() {
        let config = Config {
            default_provider: Some("anthropic".to_string()),
            ..Config::default()
        };
        let available = vec![
            test_model_entry("openai", "gpt-5.2"),
            test_model_entry("anthropic", "claude-sonnet-4"),
        ];

        let selected = select_acp_model_entry(&config, &available).expect("selected model");

        assert_eq!(selected.model.provider, "anthropic");
        assert_eq!(selected.model.id, "claude-sonnet-4");
    }

    #[test]
    fn select_acp_model_entry_prefers_default_model_when_provider_is_unset() {
        let config = Config {
            default_model: Some("gpt-5.2".to_string()),
            ..Config::default()
        };
        let available = vec![
            test_model_entry("anthropic", "claude-sonnet-4"),
            test_model_entry("openai", "gpt-5.2"),
        ];

        let selected = select_acp_model_entry(&config, &available).expect("selected model");

        assert_eq!(selected.model.provider, "openai");
        assert_eq!(selected.model.id, "gpt-5.2");
    }

    #[test]
    fn select_acp_model_entry_matches_provider_aliases() {
        // "gemini-cli" in config must resolve to the "google-gemini-cli" entry.
        let config = Config {
            default_provider: Some("gemini-cli".to_string()),
            default_model: Some("gemini-2.5-pro".to_string()),
            ..Config::default()
        };
        let available = vec![
            test_model_entry("openai", "gpt-5.2"),
            test_model_entry("google-gemini-cli", "gemini-2.5-pro"),
        ];

        let selected = select_acp_model_entry(&config, &available).expect("selected model");

        assert_eq!(selected.model.provider, "google-gemini-cli");
        assert_eq!(selected.model.id, "gemini-2.5-pro");
    }

    #[test]
    fn select_acp_model_entry_falls_back_to_first_available_model() {
        let available = vec![
            test_model_entry("openai", "gpt-5.2"),
            test_model_entry("anthropic", "claude-sonnet-4"),
        ];

        let selected =
            select_acp_model_entry(&Config::default(), &available).expect("selected model");

        assert_eq!(selected.model.provider, "openai");
        assert_eq!(selected.model.id, "gpt-5.2");
    }

    #[test]
    fn resolve_acp_thinking_level_defaults_to_highest_supported_level() {
        let config = Config::default();
        let model_entry = test_model_entry("openai", "gpt-5.2");

        let thinking = resolve_acp_thinking_level(&config, &model_entry);

        assert_eq!(thinking, crate::model::ThinkingLevel::XHigh);
    }

    #[test]
    fn resolve_acp_thinking_level_clamps_non_reasoning_models_to_off() {
        let config = Config::default();
        let mut model_entry = test_model_entry("ollama", "llama3.2");
        model_entry.model.reasoning = false;

        let thinking = resolve_acp_thinking_level(&config, &model_entry);

        assert_eq!(thinking, crate::model::ThinkingLevel::Off);
    }

    #[test]
    fn assistant_message_to_acp_content_converts_blocks() {
        use crate::model::{TextContent, ToolCall};

        let msg = AssistantMessage {
            content: vec![
                ContentBlock::Text(TextContent::new("Hello")),
                ContentBlock::ToolCall(ToolCall {
                    id: "tc1".into(),
                    name: "read".into(),
                    arguments: json!({"path": "/tmp/test.txt"}),
                    thought_signature: None,
                }),
            ],
            ..Default::default()
        };

        let items = assistant_message_to_acp_content(&msg);
        assert_eq!(items.len(), 2);

        let AcpContentItem::Text { text } = &items[0] else {
            panic!("Expected Text item");
        };
        assert_eq!(text, "Hello");

        let AcpContentItem::ToolUse { id, name, input } = &items[1] else {
            panic!("Expected ToolUse item");
        };
        assert_eq!(id, "tc1");
        assert_eq!(name, "read");
        assert_eq!(input["path"], "/tmp/test.txt");
    }
}