use super::protocol::{ClientMessage, MessageType, ServerMessage};
use crate::config::Config;
use crate::session::cancellation::SessionCancellation;
use crate::session::chat::session::{
execute_api_call_and_process_response, prepare_for_api_call, process_layers_if_enabled,
setup_and_initialize_session, setup_system_prompt_and_cache, ChatSession,
};
use crate::{log_debug, log_error, log_info};
use anyhow::Result;
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
/// WebSocket front end that serves chat sessions for one role over `ws://`.
///
/// All connections accepted by [`WebSocketServer::start`] share a single in-memory session
/// cache, so a client can continue a session by sending its id.
pub struct WebSocketServer {
    /// Address the TCP listener binds to.
    addr: SocketAddr,
    /// Configuration shared with every connection handler.
    config: Arc<Config>,
    /// Role used when creating or resuming sessions.
    role: String,
}

impl WebSocketServer {
    /// Creates a server that will listen on `host:port`.
    ///
    /// `host` may be an IPv4 literal, an IPv6 literal with or without brackets
    /// (`[::1]` or `::1`), or a host name such as `localhost`. For a host name the first
    /// address returned by the system resolver is used.
    ///
    /// # Errors
    ///
    /// Returns an error if `host` is neither a valid IP literal nor a resolvable host name.
    pub fn new(host: &str, port: u16, config: Config, role: String) -> Result<Self> {
        let addr = Self::resolve_addr(host, port)?;
        Ok(Self {
            addr,
            config: Arc::new(config),
            role,
        })
    }

    /// Turns `host` and `port` into a bindable socket address.
    fn resolve_addr(host: &str, port: u16) -> Result<SocketAddr> {
        // Fast path, identical to the historical behaviour: IPv4 or bracketed IPv6 literal.
        if let Ok(addr) = format!("{}:{}", host, port).parse::<SocketAddr>() {
            return Ok(addr);
        }
        // Fallback: bare IPv6 literals (parsed directly by std) and host names (resolved).
        let mut candidates = std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
            .map_err(|e| std::io::Error::new(e.kind(), format!("invalid host '{}': {}", host, e)))?;
        candidates.next().ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::AddrNotAvailable,
                format!("no address found for host '{}'", host),
            )
            .into()
        })
    }

    /// Binds the listener and serves clients until the process is stopped.
    ///
    /// A failed `accept` or a failed connection handler is logged and does not stop the
    /// server.
    ///
    /// # Errors
    ///
    /// Returns an error only if the listener cannot be bound; otherwise it never returns.
    pub async fn start(&self) -> Result<()> {
        let listener = TcpListener::bind(self.addr).await?;
        log_info!("WebSocket server listening on ws://{}", self.addr);
        println!("🚀 WebSocket server started on ws://{}", self.addr);
        println!("Press Ctrl+C to stop the server");

        // Session cache shared by every connection accepted below.
        let sessions: Arc<Mutex<HashMap<String, ChatSession>>> =
            Arc::new(Mutex::new(HashMap::new()));

        loop {
            let (stream, peer_addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                Err(e) => {
                    log_error!("Failed to accept connection: {}", e);
                    continue;
                }
            };
            log_info!("Connection accepted from {}", peer_addr);

            // NOTE(review): the handler is awaited inline, so the next client is accepted only
            // after this one disconnects. Serving clients concurrently would need
            // `tokio::spawn`, which requires the handler future (and `ChatSession`) to be
            // `Send` — confirm before changing.
            if let Err(e) = handle_connection(
                stream,
                peer_addr,
                Arc::clone(&self.config),
                self.role.clone(),
                Arc::clone(&sessions),
            )
            .await
            {
                log_error!("Connection handler failed for {}: {}", peer_addr, e);
            }
        }
    }
}
async fn handle_connection(
stream: TcpStream,
peer_addr: SocketAddr,
config: Arc<Config>,
role: String,
sessions: Arc<Mutex<HashMap<String, ChatSession>>>,
) -> Result<()> {
let ws_config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
.max_message_size(Some(10 * 1024 * 1024)) .max_frame_size(Some(10 * 1024 * 1024)) .accept_unmasked_frames(false);
let ws_stream = tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await?;
log_info!("WebSocket handshake completed for {}", peer_addr);
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
let welcome = ServerMessage::status(
format!("Connected to Octomind WebSocket server (role: {})", role),
None,
);
send_message(&mut ws_sender, &welcome).await?;
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
log_debug!("Received message from {}: {} bytes", peer_addr, text.len());
let client_msg = match serde_json::from_str::<ClientMessage>(&text) {
Ok(msg) => {
log_debug!(
"Parsed message: session_id={:?}, content_len={}",
msg.session_id,
msg.content.len()
);
msg
}
Err(e) => {
log_error!("Invalid JSON from {}: {}", peer_addr, e);
let error = ServerMessage::error(format!("Invalid JSON: {}", e));
send_message(&mut ws_sender, &error).await?;
continue;
}
};
if let Err(e) = client_msg.validate() {
log_error!("Message validation failed from {}: {}", peer_addr, e);
let error = ServerMessage::error(e);
send_message(&mut ws_sender, &error).await?;
continue;
}
if let Err(e) =
process_client_message(client_msg, &mut ws_sender, &config, &role, &sessions)
.await
{
log_error!("Message processing failed for {}: {}", peer_addr, e);
let error = ServerMessage::error(format!("Internal error: {}", e));
send_message(&mut ws_sender, &error).await?;
}
}
Ok(Message::Close(_)) => {
log_info!("Client {} closed connection", peer_addr);
break;
}
Ok(Message::Ping(data)) => {
log_debug!("Ping received from {}", peer_addr);
if let Err(e) = ws_sender.send(Message::Pong(data)).await {
log_error!("Failed to send pong to {}: {}", peer_addr, e);
break;
}
}
Ok(_) => {
}
Err(e) => {
log_error!("WebSocket protocol error from {}: {}", peer_addr, e);
break;
}
}
}
log_info!("Connection closed: {}", peer_addr);
Ok(())
}
async fn process_client_message(
client_msg: ClientMessage,
ws_sender: &mut futures_util::stream::SplitSink<
WebSocketStream<TcpStream>,
tokio_tungstenite::tungstenite::Message,
>,
config: &Config,
role: &str,
sessions: &Arc<Mutex<HashMap<String, ChatSession>>>,
) -> Result<()> {
log_debug!(
"Processing message: session_id={:?}, content_len={}",
client_msg.session_id,
client_msg.content.len()
);
let mut sessions_lock = sessions.lock().await;
let (mut chat_session, session_id, is_new_session) = if let Some(session_id) =
&client_msg.session_id
{
if let Some(session) = sessions_lock.remove(session_id) {
drop(sessions_lock); log_debug!("Resumed session from memory: {}", session_id);
(session, session_id.clone(), false)
} else {
drop(sessions_lock);
log_debug!("Loading session from disk: {}", session_id);
#[derive(Debug)]
#[allow(dead_code)]
struct ResumeArgs {
name: Option<String>,
resume: Option<String>,
resume_recent: bool,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
role: String,
max_retries: Option<u32>,
}
let args = ResumeArgs {
name: None,
resume: Some(session_id.clone()),
resume_recent: false,
model: None,
max_tokens: None,
temperature: None,
role: role.to_string(),
max_retries: None,
};
match setup_and_initialize_session(&args, config).await {
Ok((session, config_for_role, session_role, _first_message_processed)) => {
let mut session = session;
setup_system_prompt_and_cache(
&mut session,
&config_for_role,
&session_role,
false,
)
.await?;
log_info!("Session loaded from disk: {}", session_id);
(session, session_id.clone(), false)
}
Err(_) => {
let error = ServerMessage::error(format!(
"Session not found: {}. Please start a new session by omitting session_id.",
session_id
));
send_message(ws_sender, &error).await?;
return Ok(());
}
}
}
} else {
drop(sessions_lock); log_debug!("Creating new session with role: {}", role);
#[derive(Debug)]
#[allow(dead_code)] struct DummyArgs {
name: Option<String>,
resume: Option<String>,
resume_recent: bool,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
role: String,
max_retries: Option<u32>,
}
let args = DummyArgs {
name: None,
resume: None,
resume_recent: false,
model: None,
max_tokens: None,
temperature: None,
role: role.to_string(),
max_retries: None,
};
let (session, config_for_role, session_role, _first_message_processed) =
setup_and_initialize_session(&args, config).await?;
let mut session = session;
setup_system_prompt_and_cache(&mut session, &config_for_role, &session_role, false).await?;
let session_id = session.session.info.name.clone();
log_info!("Session created: {}", session_id);
(session, session_id, true)
};
if is_new_session {
let status = ServerMessage::status(
format!("Session created: {}", session_id),
Some(session_id.clone()),
);
send_message(ws_sender, &status).await?;
}
let current_dir = std::env::current_dir()?;
let input = client_msg.content.clone();
if input.starts_with('/') {
log_debug!(
"Processing command: {}",
input.split_whitespace().next().unwrap_or(&input)
);
let config_for_role = config.get_merged_config_for_role(role);
let mut cancellation = SessionCancellation::new();
let operation_rx = cancellation.new_operation();
let command_result = chat_session
.process_command(&input, &mut config_for_role.clone(), role, operation_rx)
.await?;
use crate::session::chat::session::commands::CommandResult;
match command_result {
CommandResult::Handled => {
log_debug!("Command executed successfully");
let status = ServerMessage::status(
"Command executed successfully".to_string(),
Some(session_id.clone()),
);
send_message(ws_sender, &status).await?;
chat_session.save()?;
sessions
.lock()
.await
.insert(session_id.clone(), chat_session);
return Ok(());
}
CommandResult::HandledWithOutput(command_output) => {
log_debug!("Command executed with structured output");
let response = ServerMessage::with_metadata(
MessageType::Status,
"Command executed successfully".to_string(),
command_output.to_json(),
Some(session_id.clone()),
);
send_message(ws_sender, &response).await?;
chat_session.save()?;
sessions
.lock()
.await
.insert(session_id.clone(), chat_session);
return Ok(());
}
CommandResult::Exit => {
log_info!("Session ended by user command");
let status =
ServerMessage::status("Session ended".to_string(), Some(session_id.clone()));
send_message(ws_sender, &status).await?;
return Ok(());
}
CommandResult::TreatAsUserInput => {
}
}
}
let config_for_role = config.get_merged_config_for_role(role);
let mut cancellation = SessionCancellation::new();
let operation_rx = cancellation.new_operation();
let first_message_processed = !chat_session.session.messages.is_empty();
log_debug!(
"Processing input through layers: first_message={}",
!first_message_processed
);
let (processed_input, layers_modified_session, _layer_cancelled) = process_layers_if_enabled(
&input,
&mut chat_session,
&config_for_role,
role,
first_message_processed,
operation_rx.clone(),
)
.await?;
if !layers_modified_session {
let final_input_with_constraints =
crate::session::chat::session::utils::append_constraints_if_exists(
&processed_input,
&config_for_role.custom_constraints_file_name,
¤t_dir,
);
chat_session.add_user_message(&final_input_with_constraints)?;
}
prepare_for_api_call(&mut chat_session, &config_for_role, operation_rx.clone()).await?;
let messages_before = chat_session.session.messages.len();
log_debug!("Executing API call: messages_before={}", messages_before);
match execute_api_call_and_process_response(
&mut chat_session,
&config_for_role,
role,
operation_rx.clone(),
false, )
.await
{
Ok(_) => {
let current_message_count = chat_session.session.messages.len();
let new_message_count = current_message_count.saturating_sub(messages_before);
log_debug!("API call completed: new_messages={}", new_message_count);
if current_message_count > messages_before {
let new_messages = &chat_session.session.messages[messages_before..];
for msg in new_messages {
match msg.role.as_str() {
"tool" => {
let content = &msg.content;
let actual_content = if let Ok(mcp_result) =
serde_json::from_str::<serde_json::Value>(content)
{
crate::mcp::extract_mcp_content(&mcp_result)
} else {
content.clone()
};
let tool_name = msg
.name
.as_ref()
.map(|n| n.to_string())
.unwrap_or_else(|| "unknown".to_string());
let tool_id = msg
.tool_call_id
.as_ref()
.map(|id| id.to_string())
.unwrap_or_else(|| "unknown".to_string());
let server_name =
crate::session::chat::response::get_tool_server_name_async(
&tool_name, config,
)
.await;
let success = if let Ok(mcp_result) =
serde_json::from_str::<serde_json::Value>(content)
{
!mcp_result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
} else {
true };
let tool_msg = ServerMessage::with_metadata(
MessageType::ToolResult,
actual_content,
json!({
"tool": tool_name,
"tool_id": tool_id,
"server": server_name,
"success": success,
"duration_ms": 0, }),
Some(session_id.clone()),
);
send_message(ws_sender, &tool_msg).await?;
}
"assistant" => {
if let Some(tool_calls_value) = &msg.tool_calls {
if let Some(tool_calls_array) = tool_calls_value.as_array() {
for tool_call in tool_calls_array {
let tool_name = tool_call
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.or_else(|| {
tool_call.get("name").and_then(|n| n.as_str())
})
.or_else(|| {
tool_call.get("tool_name").and_then(|n| n.as_str())
})
.unwrap_or("unknown");
let tool_id = tool_call
.get("id")
.and_then(|id| id.as_str())
.unwrap_or("unknown");
let server_name =
crate::session::chat::response::get_tool_server_name_async(
tool_name, config,
)
.await;
let tool_params = tool_call
.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| {
if let Some(s) = a.as_str() {
serde_json::from_str::<serde_json::Value>(s)
.ok()
} else {
Some(a.clone())
}
})
.or_else(|| tool_call.get("arguments").cloned())
.or_else(|| tool_call.get("parameters").cloned())
.unwrap_or_else(|| json!({}));
let tool_use_msg = ServerMessage::with_metadata(
MessageType::ToolUse,
format!("Executing: {}(...)", tool_name),
json!({
"tool": tool_name,
"tool_id": tool_id,
"server": server_name,
"params": tool_params,
}),
Some(session_id.clone()),
);
send_message(ws_sender, &tool_use_msg).await?;
}
}
continue;
}
if msg.content.trim().is_empty() {
continue;
}
let assistant_msg = ServerMessage::new(
MessageType::Assistant,
msg.content.clone(),
Some(session_id.clone()),
);
send_message(ws_sender, &assistant_msg).await?;
}
_ => {
}
}
}
}
let total_tokens =
chat_session.session.info.input_tokens + chat_session.session.info.output_tokens;
log_debug!(
"Session stats: tokens={}, cost=${:.4}",
total_tokens,
chat_session.session.info.total_cost
);
let cost_msg = ServerMessage::with_metadata(
MessageType::Cost,
format!(
"Session: {} tokens (${:.4})",
total_tokens, chat_session.session.info.total_cost
),
json!({
"session_tokens": total_tokens,
"session_cost": chat_session.session.info.total_cost,
"input_tokens": chat_session.session.info.input_tokens,
"output_tokens": chat_session.session.info.output_tokens,
"cached_tokens": chat_session.session.info.cached_tokens,
}),
Some(session_id.clone()),
);
send_message(ws_sender, &cost_msg).await?;
}
Err(e) => {
log_error!("API call failed: {}", e);
let error = ServerMessage::error(format!("Error: {}", e));
send_message(ws_sender, &error).await?;
}
}
log_debug!("Saving session: {}", session_id);
chat_session.save()?;
sessions
.lock()
.await
.insert(session_id.clone(), chat_session);
log_debug!("Session stored back in memory: {}", session_id);
Ok(())
}
/// Serializes `msg` as JSON and sends it to the client as one text frame.
///
/// # Errors
///
/// Fails if `msg` cannot be serialized or the frame cannot be written to the socket.
async fn send_message(
    ws_sender: &mut futures_util::stream::SplitSink<
        WebSocketStream<TcpStream>,
        tokio_tungstenite::tungstenite::Message,
    >,
    msg: &ServerMessage,
) -> Result<()> {
    let payload = serde_json::to_string(msg)?;
    log_debug!(
        "Sending message: type={:?}, size={} bytes",
        msg.message_type,
        payload.len()
    );
    let frame = Message::text(payload);
    ws_sender.send(frame).await.map_err(Into::into)
}