use std::cell::RefCell;
use std::rc::Rc;
use agent_client_protocol::{Error, ExtRequest, ExtResponse};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use crate::config::Config;
use crate::session::cancellation::SessionCancellation;
use crate::session::chat::session::commands::{process_command, CommandResult};
use crate::session::chat::session::ChatSession;
/// Method-name prefix routed by `handle_ext_method` to the slash-command executor.
///
/// Any extension method whose name starts with this string is treated as a command request.
pub const COMMAND_NAMESPACE: &str = "octomind/command";
/// Parameters of a command extension request, decoded from the request's JSON params.
#[derive(Deserialize, Debug)]
pub struct CommandRequest {
    /// Key of the target session in the shared session map.
    pub session_id: String,
    /// Command name; arguments (if any) are appended after it separated by spaces.
    pub command: String,
    /// Optional command arguments; an omitted `args` field deserializes as an empty list.
    #[serde(default)]
    pub args: Vec<String>,
}
/// Result of a command, serialized as the JSON payload of the extension response.
///
/// Field order is significant: it is the key order of the emitted JSON object.
#[derive(Serialize, Debug)]
pub struct CommandResponse {
    /// `true` when the command was handled (including an exit request).
    pub success: bool,
    /// Optional structured output produced by the command.
    pub output: Option<serde_json::Value>,
    /// Human-readable failure description when `success` is `false`.
    pub error: Option<String>,
}
/// Exclusive loan of one session taken out of the shared session map.
///
/// The session must leave the map while a command runs, because a `RefCell` borrow cannot be held
/// across an `.await`. Putting it back from `Drop` guarantees it is restored even if the command
/// future is dropped before completion or unwinds, instead of silently vanishing from the map.
struct SessionLease<'a> {
    sessions: &'a RefCell<std::collections::HashMap<String, (ChatSession, std::path::PathBuf)>>,
    session_id: &'a str,
    /// Always `Some` until `Drop` moves the entry back into the map.
    entry: Option<(ChatSession, std::path::PathBuf)>,
}

impl<'a> SessionLease<'a> {
    /// Removes `session_id` from `sessions`; returns `None` when no such session exists.
    fn take(
        sessions: &'a RefCell<std::collections::HashMap<String, (ChatSession, std::path::PathBuf)>>,
        session_id: &'a str,
    ) -> Option<Self> {
        let entry = sessions.borrow_mut().remove(session_id)?;
        Some(Self {
            sessions,
            session_id,
            entry: Some(entry),
        })
    }

    /// Mutable access to the leased session.
    fn session_mut(&mut self) -> &mut ChatSession {
        &mut self
            .entry
            .as_mut()
            .expect("session lease is only emptied in Drop")
            .0
    }

    /// Working directory stored alongside the leased session.
    fn cwd(&self) -> &std::path::Path {
        &self
            .entry
            .as_ref()
            .expect("session lease is only emptied in Drop")
            .1
    }
}

impl Drop for SessionLease<'_> {
    fn drop(&mut self) {
        if let Some(entry) = self.entry.take() {
            // `try_borrow_mut`: if the map is unexpectedly borrowed (e.g. while unwinding) we must
            // not panic inside Drop; the session is then lost, exactly as it was before the lease.
            if let Ok(mut sessions) = self.sessions.try_borrow_mut() {
                sessions.insert(self.session_id.to_owned(), entry);
            }
        }
    }
}

/// Runs a slash command against an existing session and reports the outcome.
///
/// The session is taken out of `sessions` for the duration of the command and is always put back
/// afterwards (also if this future is dropped mid-command). Before running, the MCP working
/// directory is set to the session's stored directory and a new operation is started on the
/// session's [`SessionCancellation`] (created on first use); its receiver is handed to
/// `process_command`. The command line is `command` followed by `args` joined with spaces.
///
/// Returns a failed [`CommandResponse`] when the session does not exist, when the input is not a
/// recognised command (`TreatAsUserInput`), or when the command itself returns an error.
pub async fn execute_command(
    request: &CommandRequest,
    sessions: &Rc<RefCell<std::collections::HashMap<String, (ChatSession, std::path::PathBuf)>>>,
    config: &RefCell<Config>,
    role: &str,
    cancellations: &Rc<RefCell<std::collections::HashMap<String, SessionCancellation>>>,
) -> CommandResponse {
    let Some(mut lease) = SessionLease::take(sessions, &request.session_id) else {
        return CommandResponse {
            success: false,
            output: None,
            error: Some(format!("session not found: {}", request.session_id)),
        };
    };
    crate::mcp::set_session_working_directory(lease.cwd().to_path_buf());

    // Single lookup: reuse the session's cancellation handle or create it on first use.
    let operation_rx = cancellations
        .borrow_mut()
        .entry(request.session_id.clone())
        .or_insert_with(SessionCancellation::new)
        .new_operation();

    let full_command = if request.args.is_empty() {
        request.command.clone()
    } else {
        format!("{} {}", request.command, request.args.join(" "))
    };

    // NOTE(review): the command runs against a private copy of the config, so any config change
    // made by the command is discarded when it returns — confirm this is intended.
    let mut config_clone = config.borrow().clone();
    let result = process_command(
        lease.session_mut(),
        &full_command,
        &mut config_clone,
        role,
        operation_rx,
    )
    .await;
    // Return the session to the shared map before reporting the result.
    drop(lease);

    match result {
        Ok(CommandResult::Handled) => CommandResponse {
            success: true,
            output: None,
            error: None,
        },
        Ok(CommandResult::HandledWithOutput(output)) => CommandResponse {
            success: true,
            output: Some(output.to_json()),
            error: None,
        },
        Ok(CommandResult::Exit) => CommandResponse {
            success: true,
            output: Some(serde_json::json!({ "action": "exit" })),
            error: None,
        },
        Ok(CommandResult::TreatAsUserInput) => CommandResponse {
            success: false,
            output: None,
            error: Some(format!("Unknown command: {}", request.command)),
        },
        Err(e) => CommandResponse {
            success: false,
            output: None,
            error: Some(e.to_string()),
        },
    }
}
pub async fn handle_ext_method(
request: ExtRequest,
sessions: &Rc<RefCell<std::collections::HashMap<String, (ChatSession, std::path::PathBuf)>>>,
config: &RefCell<Config>,
role: &str,
cancellations: &Rc<RefCell<std::collections::HashMap<String, SessionCancellation>>>,
) -> Result<ExtResponse, Error> {
if !request.method.starts_with(COMMAND_NAMESPACE) {
return Err(Error::method_not_found());
}
let command_request: CommandRequest = match serde_json::from_str(request.params.get()) {
Ok(req) => req,
Err(e) => {
let response = CommandResponse {
success: false,
output: None,
error: Some(format!("Invalid request: {}", e)),
};
let raw = RawValue::from_string(serde_json::to_string(&response).unwrap()).unwrap();
return Ok(ExtResponse::new(std::sync::Arc::from(raw)));
}
};
let response = execute_command(&command_request, sessions, config, role, cancellations).await;
let raw = RawValue::from_string(serde_json::to_string(&response).unwrap()).unwrap();
Ok(ExtResponse::new(std::sync::Arc::from(raw)))
}