use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Lifecycle state of a session.
///
/// Variants are (de)serialized as snake_case strings (for example
/// `awaiting_input`). [`SessionState::Created`] is the `Default` value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionState {
    /// Default state; serialized as `created`.
    #[default]
    Created,
    /// Serialized as `active`.
    Active,
    /// Serialized as `awaiting_input`.
    AwaitingInput,
    /// Serialized as `completed`.
    Completed,
    /// Serialized as `cancelled`.
    Cancelled,
    /// Serialized as `failed`.
    Failed,
}
/// Session record: identifier, lifecycle state, timestamps, metadata and a
/// turn counter.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AcpSession {
    /// Identifier of the session.
    pub session_id: String,
    /// Current lifecycle state.
    pub state: SessionState,
    /// Creation timestamp; `AcpSession::new` fills it with an RFC 3339 UTC
    /// string.
    pub created_at: String,
    /// Timestamp of the latest state change or turn increment. Left out of
    /// the serialized form while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_activity_at: Option<String>,
    /// Free-form key/value data. Defaults to empty when missing on input and
    /// is left out of the serialized form while empty.
    #[serde(skip_serializing_if = "HashMap::is_empty", default)]
    pub metadata: HashMap<String, Value>,
    /// Number of turns counted so far; defaults to `0` when missing on input.
    #[serde(default)]
    pub turn_count: u32,
}
impl AcpSession {
    /// Creates a session in the [`SessionState::Created`] state.
    ///
    /// `created_at` is stamped with the current UTC time in RFC 3339 form.
    /// The new session has no recorded activity, empty metadata and a turn
    /// count of zero.
    pub fn new(session_id: impl Into<String>) -> Self {
        Self {
            session_id: session_id.into(),
            state: SessionState::Created,
            created_at: chrono::Utc::now().to_rfc3339(),
            last_activity_at: None,
            metadata: HashMap::new(),
            turn_count: 0,
        }
    }

    /// Replaces the session state and refreshes `last_activity_at` with the
    /// current UTC time (RFC 3339).
    pub fn set_state(&mut self, state: SessionState) {
        self.state = state;
        self.last_activity_at = Some(chrono::Utc::now().to_rfc3339());
    }

    /// Counts one more turn and refreshes `last_activity_at` with the current
    /// UTC time (RFC 3339).
    ///
    /// The counter saturates at `u32::MAX`: a plain `+= 1` would panic in
    /// debug builds and silently wrap back to zero in release builds.
    pub fn increment_turn(&mut self) {
        self.turn_count = self.turn_count.saturating_add(1);
        self.last_activity_at = Some(chrono::Utc::now().to_rfc3339());
    }
}
/// Parameter payload for starting a new session.
///
/// Every field is optional on the wire, so the `Default` value serializes to
/// an empty JSON object.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SessionNewParams {
    /// Free-form key/value data. Defaults to empty when missing on input and
    /// is left out of the serialized form while empty.
    #[serde(skip_serializing_if = "HashMap::is_empty", default)]
    pub metadata: HashMap<String, Value>,
    /// Optional workspace description; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workspace: Option<WorkspaceContext>,
    /// Optional model settings; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,
}
/// Result payload paired with [`SessionNewParams`]: the session identifier and
/// its state.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionNewResult {
    /// Identifier of the session.
    pub session_id: String,
    /// Session state; falls back to `SessionState::default()` when the field
    /// is missing on input.
    #[serde(default)]
    pub state: SessionState,
}
/// Parameter payload naming the session to load.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionLoadParams {
    /// Identifier of the session to load.
    pub session_id: String,
}
/// Result payload carrying a session together with its recorded turns.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionLoadResult {
    /// The session record.
    pub session: AcpSession,
    /// Recorded conversation turns. Defaults to empty when missing on input
    /// and is left out of the serialized form while empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub history: Vec<ConversationTurn>,
}
/// Parameter payload for sending prompt content to a session.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionPromptParams {
    /// Identifier of the target session.
    pub session_id: String,
    /// Prompt items, in order.
    pub content: Vec<PromptContent>,
    /// Free-form key/value data. Defaults to empty when missing on input and
    /// is left out of the serialized form while empty.
    #[serde(skip_serializing_if = "HashMap::is_empty", default)]
    pub metadata: HashMap<String, Value>,
}
/// One item of prompt content.
///
/// Internally tagged on the wire: the `type` field holds the snake_case
/// variant name (`text`, `image` or `context`) next to the variant's fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum PromptContent {
    /// Plain text.
    Text {
        /// The text itself.
        text: String,
    },
    /// An image.
    Image {
        /// Image payload. NOTE(review): presumably encoded bytes or a URL
        /// depending on `is_url`; confirm against producers.
        data: String,
        /// MIME type of the image.
        mime_type: String,
        /// Defaults to `false` when missing on input.
        #[serde(default)]
        is_url: bool,
    },
    /// Content associated with a path.
    Context {
        /// Path the content belongs to.
        path: String,
        /// The content itself.
        content: String,
        /// Optional language label; omitted from output while `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        language: Option<String>,
    },
}
impl PromptContent {
    /// Builds a [`PromptContent::Text`] item from anything convertible into a
    /// `String`.
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text { text }
    }

    /// Builds a [`PromptContent::Context`] item for `path` and `content`,
    /// leaving `language` unset.
    pub fn context(path: impl Into<String>, content: impl Into<String>) -> Self {
        let (path, content) = (path.into(), content.into());
        Self::Context {
            path,
            content,
            language: None,
        }
    }
}
/// Result payload for a prompt: the turn identifier, optional response text,
/// tool calls and the turn status.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionPromptResult {
    /// Identifier of the turn.
    pub turn_id: String,
    /// Optional response text; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,
    /// Tool call records. Defaults to empty when missing on input and is left
    /// out of the serialized form while empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<ToolCallRecord>,
    /// Status of the turn.
    pub status: TurnStatus,
}
/// Status of a single turn, (de)serialized as a snake_case string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TurnStatus {
    /// Serialized as `completed`.
    Completed,
    /// Serialized as `cancelled`.
    Cancelled,
    /// Serialized as `failed`.
    Failed,
    /// Serialized as `awaiting_input`.
    AwaitingInput,
}
/// Parameter payload asking for a choice among permission options for a tool
/// call.
///
/// Field names are camelCase on the wire (`sessionId`, `toolCall`, `options`).
/// NOTE(review): the other payload types in this file keep snake_case field
/// names; confirm the difference is intentional.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestPermissionParams {
    /// Identifier of the session; `sessionId` on the wire.
    pub session_id: String,
    /// The tool call the request refers to; `toolCall` on the wire.
    pub tool_call: ToolCallRecord,
    /// The options to choose from.
    pub options: Vec<PermissionOption>,
}
/// One selectable entry of a [`RequestPermissionParams`] option list.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PermissionOption {
    /// Identifier of the option.
    pub id: String,
    /// Label of the option.
    pub label: String,
    /// Optional longer description; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// Outcome of a permission request.
///
/// Internally tagged on the wire: the `outcome` field holds the snake_case
/// variant name (`selected` or `cancelled`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "outcome")]
pub enum RequestPermissionResult {
    /// An option was chosen.
    Selected {
        /// Identifier of the chosen option.
        option_id: String,
    },
    /// No option was chosen; carries no further fields.
    Cancelled,
}
/// Parameter payload for a cancel request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionCancelParams {
    /// Identifier of the session.
    pub session_id: String,
    /// Optional turn identifier; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_id: Option<String>,
}
/// Notification wrapping a [`SessionUpdate`] with its session and turn
/// identifiers.
///
/// The update is flattened, so its `update_type` tag and variant fields sit
/// at the same JSON level as `session_id` and `turn_id`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SessionUpdateNotification {
    /// Identifier of the session.
    pub session_id: String,
    /// Identifier of the turn.
    pub turn_id: String,
    /// The update itself, merged into this object on the wire.
    #[serde(flatten)]
    pub update: SessionUpdate,
}
/// A single update emitted for a session.
///
/// Internally tagged on the wire: the `update_type` field holds the
/// snake_case variant name (for example `message_delta`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "update_type")]
pub enum SessionUpdate {
    /// A piece of message text; tag `message_delta`.
    MessageDelta {
        /// The text fragment.
        delta: String,
    },
    /// A tool call record; tag `tool_call_start`.
    ToolCallStart {
        /// The tool call.
        tool_call: ToolCallRecord,
    },
    /// The result for a tool call; tag `tool_call_end`.
    ToolCallEnd {
        /// Identifier of the tool call.
        tool_call_id: String,
        /// Result value as arbitrary JSON.
        result: Value,
    },
    /// The status of a turn; tag `turn_complete`.
    TurnComplete {
        /// Status of the turn.
        status: TurnStatus,
    },
    /// An error report; tag `error`.
    Error {
        /// Error code.
        code: String,
        /// Error message.
        message: String,
    },
    /// A tool execution request; tag `server_request`.
    ServerRequest {
        /// The request.
        request: ToolExecutionRequest,
    },
}
/// Description of a workspace: root path, optional name and active files.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkspaceContext {
    /// Root path of the workspace.
    pub root_path: String,
    /// Optional workspace name; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Active files. Defaults to empty when missing on input and is left out
    /// of the serialized form while empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub active_files: Vec<String>,
}
/// Optional model settings; each field is omitted from output while `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelPreferences {
    /// Optional model identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    /// Optional temperature value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional token limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
}
/// Record of one tool call: identifier, tool name, arguments, optional result
/// and a timestamp.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCallRecord {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Arguments as arbitrary JSON.
    pub arguments: Value,
    /// Optional result as arbitrary JSON; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Timestamp string; this type does not enforce a format.
    pub timestamp: String,
}
/// A tool call paired with a request identifier.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolExecutionRequest {
    /// Identifier of the request.
    pub request_id: String,
    /// The tool call.
    pub tool_call: ToolCallRecord,
}
/// Result of a tool execution, identified by request and tool call
/// identifiers.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolExecutionResult {
    /// Identifier of the request.
    pub request_id: String,
    /// Identifier of the tool call.
    pub tool_call_id: String,
    /// Output as arbitrary JSON.
    pub output: Value,
    /// Success flag.
    pub success: bool,
    /// Optional error text; omitted from output while `None`.
    /// NOTE(review): presumably set only when `success` is `false`; this type
    /// does not enforce that.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Notification pairing a [`ToolExecutionRequest`] with a session identifier.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerRequestNotification {
    /// Identifier of the session.
    pub session_id: String,
    /// The request.
    pub request: ToolExecutionRequest,
}
/// One recorded turn: the prompt, optional response, tool calls and a
/// timestamp.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationTurn {
    /// Identifier of the turn.
    pub turn_id: String,
    /// Prompt items of the turn.
    pub prompt: Vec<PromptContent>,
    /// Optional response text; omitted from output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<String>,
    /// Tool call records. Defaults to empty when missing on input and is left
    /// out of the serialized form while empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub tool_calls: Vec<ToolCallRecord>,
    /// Timestamp string; this type does not enforce a format.
    pub timestamp: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Default params serialize to `{}` because every field is skipped while
    /// empty or `None`.
    #[test]
    fn test_session_new_params() {
        let params = SessionNewParams::default();
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json, json!({}));
    }

    /// Text content carries the `type` tag next to its `text` field.
    #[test]
    fn test_prompt_content_text() {
        let content = PromptContent::text("Hello, world!");
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "text");
        assert_eq!(json["text"], "Hello, world!");
    }

    /// A message delta is tagged `message_delta` under `update_type`.
    #[test]
    fn test_session_update_message_delta() {
        let update = SessionUpdate::MessageDelta {
            delta: "Hello".to_string(),
        };
        let json = serde_json::to_value(&update).unwrap();
        assert_eq!(json["update_type"], "message_delta");
        assert_eq!(json["delta"], "Hello");
    }

    /// `set_state` changes the state and records an activity timestamp.
    #[test]
    fn test_session_state_transitions() {
        let mut session = AcpSession::new("test-session");
        assert_eq!(session.state, SessionState::Created);
        session.set_state(SessionState::Active);
        assert_eq!(session.state, SessionState::Active);
        assert!(session.last_activity_at.is_some());
    }

    /// A server request update nests the request under `request`.
    #[test]
    fn server_request_update_serializes_correctly() {
        let tool_call = ToolCallRecord {
            id: "tc-1".to_string(),
            name: "unified_search".to_string(),
            arguments: json!({"query": "fn main"}),
            result: None,
            timestamp: "2025-01-01T00:00:00Z".to_string(),
        };
        let request = ToolExecutionRequest {
            request_id: "req-1".to_string(),
            tool_call,
        };
        let update = SessionUpdate::ServerRequest { request };
        let json = serde_json::to_value(&update).unwrap();
        assert_eq!(json["update_type"], "server_request");
        assert_eq!(json["request"]["request_id"], "req-1");
    }

    /// A successful result omits the `error` key entirely.
    #[test]
    fn tool_execution_result_success_serializes() {
        let result = ToolExecutionResult {
            request_id: "req-1".to_string(),
            tool_call_id: "tc-1".to_string(),
            output: json!({"matches": []}),
            success: true,
            error: None,
        };
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["success"], true);
        assert!(json.get("error").is_none());
    }

    /// A failed result carries its error text.
    #[test]
    fn tool_execution_result_failure_includes_error() {
        let result = ToolExecutionResult {
            request_id: "req-1".to_string(),
            tool_call_id: "tc-1".to_string(),
            output: Value::Null,
            success: false,
            error: Some("permission denied".to_string()),
        };
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["success"], false);
        assert_eq!(json["error"], "permission denied");
    }
}