Skip to main content

localgpt_server/
http.rs

1//! HTTP server for LocalGPT
2//!
3//! Supports multiple sessions with session ID-based routing.
4//! Sessions are created on demand and cached for reuse.
5
6use anyhow::Result;
7use axum::{
8    Router,
9    extract::{
10        Path, Query, Request, State,
11        ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
12    },
13    http::{StatusCode, header},
14    middleware::{self, Next},
15    response::{
16        IntoResponse, Json, Response,
17        sse::{Event, Sse},
18    },
19    routing::{delete, get, post},
20};
21use futures::{SinkExt, StreamExt};
22use rust_embed::RustEmbed;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use std::collections::HashMap;
26use std::convert::Infallible;
27use std::net::SocketAddr;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use tokio::sync::Mutex;
31use tower_http::cors::{Any, CorsLayer};
32use tower_http::limit::RequestBodyLimitLayer;
33use tracing::{debug, info};
34
35use localgpt_core::agent::{Agent, AgentConfig, StreamEvent, extract_tool_detail};
36use localgpt_core::concurrency::{TurnGate, WorkspaceLock};
37use localgpt_core::config::Config;
38use localgpt_core::heartbeat::{HeartbeatStatus, get_last_heartbeat_event};
39use localgpt_core::memory::MemoryManager;
40
/// Embedded UI assets.
///
/// `RustEmbed` is derived over the `ui/` folder; `serve_ui_asset` looks files
/// up by relative path through `UiAssets::get`.
#[derive(RustEmbed)]
#[folder = "ui/"]
struct UiAssets;
45
/// Session timeout (30 minutes of inactivity). The periodic cleanup task
/// drops any session whose `last_accessed` is older than this.
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Maximum number of concurrent sessions kept in the in-memory cache.
/// Creating a session at this limit evicts the least recently accessed one.
const MAX_SESSIONS: usize = 100;

/// Agent ID for HTTP sessions; used when listing and saving persisted sessions.
const HTTP_AGENT_ID: &str = "http";
54
/// HTTP server handle.
///
/// Holds the pieces that `run` copies into the shared application state.
pub struct Server {
    /// Configuration, cloned from the caller at construction time.
    config: Config,
    /// Turn gate; cloned into `AppState` when the server starts.
    turn_gate: TurnGate,
    /// Bridge manager; cloned into `AppState` when the server starts.
    bridge_manager: crate::security::BridgeManager,
}
60
/// One cached session: the agent plus the bookkeeping used for idle expiry
/// and periodic persistence.
pub(crate) struct SessionEntry {
    /// Agent that owns this session's conversation.
    agent: Agent,
    /// Last time a handler touched this session; drives expiry and eviction.
    last_accessed: Instant,
    /// Whether session has unsaved changes
    dirty: bool,
}
67
/// State shared by every request handler and background task (wrapped in an
/// `Arc` by `Server::run`).
pub(crate) struct AppState {
    /// Server configuration (a clone of the one given to `Server`).
    pub(crate) config: Config,
    /// In-memory session cache keyed by session ID, guarded by an async mutex.
    pub(crate) sessions: Mutex<HashMap<String, SessionEntry>>,
    /// Shared MemoryManager to avoid reinitializing embedding provider
    pub(crate) memory: MemoryManager,
    /// In-process turn gate shared with heartbeat runner
    turn_gate: TurnGate,
    /// Cross-process workspace lock
    workspace_lock: WorkspaceLock,
    /// Per-IP rate limiter
    rate_limiter: Arc<crate::rate_limiter::RateLimiter>,
    /// Bridge manager for tracking active connections
    pub(crate) bridge_manager: crate::security::BridgeManager,
}
82
83impl Server {
84    pub fn new(config: &Config) -> Result<Self> {
85        Ok(Self {
86            config: config.clone(),
87            turn_gate: TurnGate::new(),
88            bridge_manager: crate::security::BridgeManager::new(),
89        })
90    }
91
92    /// Create a server with a shared TurnGate (for daemon mode where
93    /// heartbeat and HTTP share concurrency control).
94    pub fn new_with_gate(config: &Config, turn_gate: TurnGate) -> Result<Self> {
95        Ok(Self {
96            config: config.clone(),
97            turn_gate,
98            bridge_manager: crate::security::BridgeManager::new(),
99        })
100    }
101
102    /// Create a server with both a shared TurnGate and a shared BridgeManager.
103    pub fn new_daemon(
104        config: &Config,
105        turn_gate: TurnGate,
106        bridge_manager: crate::security::BridgeManager,
107    ) -> Result<Self> {
108        Ok(Self {
109            config: config.clone(),
110            turn_gate,
111            bridge_manager,
112        })
113    }
114
115    pub async fn run(&self) -> Result<()> {
116        // Create shared MemoryManager once to avoid reinitializing embedding provider
117        let memory =
118            MemoryManager::new_with_full_config(&self.config.memory, Some(&self.config), "main")?;
119
120        let workspace_lock = WorkspaceLock::new()?;
121        let rate_limiter = crate::rate_limiter::create_rate_limiter(&self.config.server.rate_limit);
122
123        let state = Arc::new(AppState {
124            config: self.config.clone(),
125            sessions: Mutex::new(HashMap::new()),
126            memory,
127            turn_gate: self.turn_gate.clone(),
128            workspace_lock,
129            rate_limiter,
130            bridge_manager: self.bridge_manager.clone(),
131        });
132
133        // Load persisted sessions on startup
134        if let Err(e) = load_persisted_sessions(&state).await {
135            info!("Could not load persisted sessions: {}", e);
136        }
137
138        // Spawn session cleanup task
139        let cleanup_state = state.clone();
140        tokio::spawn(async move {
141            let mut interval = tokio::time::interval(Duration::from_secs(60));
142            loop {
143                interval.tick().await;
144                cleanup_expired_sessions(&cleanup_state).await;
145            }
146        });
147
148        // Spawn session save task (save every 5 minutes)
149        let save_state = state.clone();
150        tokio::spawn(async move {
151            let mut interval = tokio::time::interval(Duration::from_secs(300));
152            loop {
153                interval.tick().await;
154                save_dirty_sessions(&save_state).await;
155            }
156        });
157
158        let cors = if self.config.server.cors_origins.is_empty() {
159            // Remove permissive Any CORS; default to common local development origins
160            let default_origins = vec![
161                "http://localhost:3000"
162                    .parse::<axum::http::HeaderValue>()
163                    .unwrap(),
164                "http://127.0.0.1:3000"
165                    .parse::<axum::http::HeaderValue>()
166                    .unwrap(),
167                "http://localhost:8080"
168                    .parse::<axum::http::HeaderValue>()
169                    .unwrap(),
170                "http://127.0.0.1:8080"
171                    .parse::<axum::http::HeaderValue>()
172                    .unwrap(),
173                "http://localhost:1420"
174                    .parse::<axum::http::HeaderValue>()
175                    .unwrap(),
176            ];
177            CorsLayer::new()
178                .allow_origin(default_origins)
179                .allow_methods(Any)
180                .allow_headers(Any)
181        } else {
182            let origins: Vec<axum::http::HeaderValue> = self
183                .config
184                .server
185                .cors_origins
186                .iter()
187                .filter_map(|o| o.parse().ok())
188                .collect();
189            CorsLayer::new()
190                .allow_origin(origins)
191                .allow_methods(Any)
192                .allow_headers(Any)
193        };
194
195        // Public routes (no auth required)
196        let public_routes = Router::new()
197            .route("/", get(serve_ui_index))
198            .route("/ui/{*path}", get(serve_ui_file))
199            .route("/health", get(health_check))
200            .route("/api/auth/status", get(auth_status));
201
202        // OpenAI-compatible API routes (auth required if token configured)
203        let openai_routes = Router::new()
204            .route(
205                "/v1/chat/completions",
206                post(crate::openai_compat::chat_completions),
207            )
208            .route("/v1/models", get(crate::openai_compat::list_models))
209            .layer(middleware::from_fn_with_state(
210                state.clone(),
211                rate_limit_middleware,
212            ))
213            .layer(middleware::from_fn_with_state(
214                state.clone(),
215                auth_middleware,
216            ));
217
218        // Protected API routes (auth required if token configured)
219        let api_routes = Router::new()
220            .route("/api/sessions", post(create_session))
221            .route("/api/sessions", get(list_sessions))
222            .route("/api/sessions/{session_id}", delete(delete_session))
223            .route("/api/sessions/{session_id}", get(get_session_status))
224            .route(
225                "/api/sessions/{session_id}/messages",
226                get(get_session_messages),
227            )
228            .route("/api/sessions/{session_id}/compact", post(compact_session))
229            .route("/api/sessions/{session_id}/clear", post(clear_session))
230            .route("/api/sessions/{session_id}/model", post(set_session_model))
231            .route("/api/chat", post(chat))
232            .route("/api/chat/stream", post(chat_stream))
233            .route("/api/ws", get(websocket_handler))
234            .route("/api/memory/search", get(memory_search))
235            .route("/api/memory/stats", get(memory_stats))
236            .route("/api/memory/reindex", post(memory_reindex))
237            .route("/api/status", get(status))
238            .route("/api/config", get(get_config))
239            .route("/api/heartbeat/status", get(heartbeat_status))
240            .route("/api/bridges", get(list_bridges))
241            .route("/api/saved-sessions", get(list_saved_sessions))
242            .route("/api/saved-sessions/{session_id}", get(get_saved_session))
243            .route("/api/logs/daemon", get(get_daemon_logs))
244            .layer(middleware::from_fn_with_state(
245                state.clone(),
246                rate_limit_middleware,
247            ))
248            .layer(middleware::from_fn_with_state(
249                state.clone(),
250                auth_middleware,
251            ));
252
253        let app = public_routes
254            .merge(api_routes)
255            .merge(openai_routes)
256            .layer(RequestBodyLimitLayer::new(
257                self.config.server.max_request_body,
258            ))
259            .layer(cors)
260            .with_state(state);
261
262        let addr: SocketAddr =
263            format!("{}:{}", self.config.server.bind, self.config.server.port).parse()?;
264
265        info!("Starting HTTP server on http://{}", addr);
266
267        let listener = tokio::net::TcpListener::bind(addr).await?;
268        axum::serve(listener, app).await?;
269
270        Ok(())
271    }
272}
273
/// Error response type: an HTTP status code paired with the message used as
/// the response body.
struct AppError(StatusCode, String);
276
277impl IntoResponse for AppError {
278    fn into_response(self) -> Response {
279        (self.0, self.1).into_response()
280    }
281}
282
/// Compare two byte slices without stopping at the first differing byte, so
/// the time taken does not reveal where an auth token mismatch occurs.
///
/// Slices of different length are rejected immediately; only the contents of
/// equal-length inputs are compared in full.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // OR together the XOR of every byte pair; any difference leaves a bit set.
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
293
294// Auth middleware for API routes
295async fn auth_middleware(
296    State(state): State<Arc<AppState>>,
297    request: Request,
298    next: Next,
299) -> Result<Response, StatusCode> {
300    // If no token configured, pass through (backward compat)
301    let Some(expected) = &state.config.server.auth_token else {
302        return Ok(next.run(request).await);
303    };
304
305    let auth_header = request
306        .headers()
307        .get("authorization")
308        .and_then(|v| v.to_str().ok());
309
310    match auth_header {
311        Some(h) if h.starts_with("Bearer ") => {
312            let token = &h[7..];
313            // Use constant-time comparison to prevent timing attacks
314            if constant_time_eq(token.as_bytes(), expected.as_bytes()) {
315                Ok(next.run(request).await)
316            } else {
317                debug!("Auth failed: invalid token");
318                Err(StatusCode::UNAUTHORIZED)
319            }
320        }
321        _ => {
322            debug!("Auth failed: missing or invalid Authorization header");
323            Err(StatusCode::UNAUTHORIZED)
324        }
325    }
326}
327
328// Rate limit middleware for API routes
329async fn rate_limit_middleware(
330    State(state): State<Arc<AppState>>,
331    request: Request,
332    next: Next,
333) -> Result<Response, Response> {
334    let ip = request
335        .extensions()
336        .get::<axum::extract::ConnectInfo<SocketAddr>>()
337        .map(|ci| ci.0.ip())
338        .unwrap_or_else(|| std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
339
340    if !state.rate_limiter.check(ip).await {
341        return Err((
342            StatusCode::TOO_MANY_REQUESTS,
343            [(header::RETRY_AFTER, "60")],
344            "Rate limit exceeded",
345        )
346            .into_response());
347    }
348
349    Ok(next.run(request).await)
350}
351
352// Auth status endpoint (public, tells client if auth is required)
353async fn auth_status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
354    Json(json!({
355        "auth_required": state.config.server.auth_token.is_some()
356    }))
357}
358
359// Session cleanup task
360async fn cleanup_expired_sessions(state: &Arc<AppState>) {
361    let mut sessions = state.sessions.lock().await;
362    let before_count = sessions.len();
363
364    sessions.retain(|id, entry| {
365        let expired = entry.last_accessed.elapsed() > SESSION_TIMEOUT;
366        if expired {
367            debug!("Expiring session: {}", id);
368        }
369        !expired
370    });
371
372    let removed = before_count - sessions.len();
373    if removed > 0 {
374        info!("Cleaned up {} expired sessions", removed);
375    }
376}
377
378// Load persisted sessions from disk
379async fn load_persisted_sessions(state: &Arc<AppState>) -> Result<(), anyhow::Error> {
380    use localgpt_core::agent::list_sessions_for_agent;
381    use std::sync::Arc as StdArc;
382
383    let sessions_list = list_sessions_for_agent(HTTP_AGENT_ID)?;
384    let mut loaded = 0;
385
386    for session_info in sessions_list.into_iter().take(MAX_SESSIONS) {
387        let agent_config = AgentConfig {
388            model: state.config.agent.default_model.clone(),
389            context_window: state.config.agent.context_window,
390            reserve_tokens: state.config.agent.reserve_tokens,
391        };
392
393        let memory = StdArc::new(state.memory.clone());
394        let mut agent = Agent::new(agent_config, &state.config, memory).await?;
395
396        // Try to resume the session
397        if agent.resume_session(&session_info.id).await.is_ok() {
398            let mut sessions = state.sessions.lock().await;
399            sessions.insert(
400                session_info.id.clone(),
401                SessionEntry {
402                    agent,
403                    last_accessed: Instant::now(),
404                    dirty: false,
405                },
406            );
407            loaded += 1;
408        }
409    }
410
411    if loaded > 0 {
412        info!("Loaded {} persisted HTTP sessions", loaded);
413    }
414
415    Ok(())
416}
417
418// Save dirty sessions to disk
419async fn save_dirty_sessions(state: &Arc<AppState>) {
420    let mut sessions = state.sessions.lock().await;
421    let mut saved = 0;
422
423    for (id, entry) in sessions.iter_mut() {
424        if entry.dirty {
425            if let Err(e) = entry.agent.save_session_for_agent(HTTP_AGENT_ID).await {
426                debug!("Failed to save session {}: {}", id, e);
427            } else {
428                entry.dirty = false;
429                saved += 1;
430            }
431        }
432    }
433
434    if saved > 0 {
435        info!("Saved {} HTTP sessions to disk", saved);
436    }
437}
438
439// Get or create a session
440async fn get_or_create_session(
441    state: &Arc<AppState>,
442    session_id: Option<String>,
443) -> Result<String, AppError> {
444    let mut sessions = state.sessions.lock().await;
445
446    // If session_id provided, try to use existing session
447    if let Some(ref id) = session_id
448        && sessions.contains_key(id)
449    {
450        // Update last accessed time
451        if let Some(entry) = sessions.get_mut(id) {
452            entry.last_accessed = Instant::now();
453        }
454        return Ok(id.clone());
455    }
456
457    // Check session limit
458    if sessions.len() >= MAX_SESSIONS {
459        // Try to remove oldest session
460        if let Some(oldest_id) = sessions
461            .iter()
462            .min_by_key(|(_, e)| e.last_accessed)
463            .map(|(id, _)| id.clone())
464        {
465            sessions.remove(&oldest_id);
466            info!("Removed oldest session {} to make room", oldest_id);
467        }
468    }
469
470    // Create new session
471    let new_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
472
473    let agent_config = AgentConfig {
474        model: state.config.agent.default_model.clone(),
475        context_window: state.config.agent.context_window,
476        reserve_tokens: state.config.agent.reserve_tokens,
477    };
478
479    let memory = std::sync::Arc::new(state.memory.clone());
480    let mut agent = Agent::new(agent_config, &state.config, memory)
481        .await
482        .map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
483
484    agent
485        .new_session()
486        .await
487        .map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
488
489    sessions.insert(
490        new_id.clone(),
491        SessionEntry {
492            agent,
493            last_accessed: Instant::now(),
494            dirty: true, // New sessions should be saved
495        },
496    );
497
498    info!("Created new session: {}", new_id);
499    Ok(new_id)
500}
501
/// Health check endpoint: always answers with the literal body `OK`.
async fn health_check() -> &'static str {
    const BODY: &str = "OK";
    BODY
}
506
507// Serve UI index.html at root
508async fn serve_ui_index() -> Response {
509    serve_ui_asset("index.html")
510}
511
512// Serve UI static files
513async fn serve_ui_file(Path(path): Path<String>) -> Response {
514    serve_ui_asset(&path)
515}
516
517// Helper to serve embedded UI assets
518fn serve_ui_asset(path: &str) -> Response {
519    match UiAssets::get(path) {
520        Some(content) => {
521            let mime = match path.rsplit('.').next() {
522                Some("js") => "application/javascript".to_string(),
523                Some("wasm") => "application/wasm".to_string(),
524                _ => mime_guess::from_path(path)
525                    .first_or_octet_stream()
526                    .to_string(),
527            };
528            ([(header::CONTENT_TYPE, mime)], content.data.to_vec()).into_response()
529        }
530        None => (StatusCode::NOT_FOUND, "Not found").into_response(),
531    }
532}
533
// Status endpoint
/// JSON body returned by the `/api/status` handler.
#[derive(Serialize)]
struct StatusResponse {
    /// Crate version taken from `CARGO_PKG_VERSION` at compile time.
    version: String,
    /// Default model from the agent configuration.
    model: String,
    /// Chunk count reported by the memory manager (0 if it cannot be read).
    memory_chunks: usize,
    /// Number of sessions currently in the in-memory cache.
    active_sessions: usize,
    /// Value of `MemoryManager::is_brand_new()`.
    is_brand_new: bool,
}
543
544async fn status(State(state): State<Arc<AppState>>) -> Json<StatusResponse> {
545    let sessions = state.sessions.lock().await;
546
547    Json(StatusResponse {
548        version: env!("CARGO_PKG_VERSION").to_string(),
549        model: state.config.agent.default_model.clone(),
550        memory_chunks: state.memory.chunk_count().unwrap_or(0),
551        active_sessions: sessions.len(),
552        is_brand_new: state.memory.is_brand_new(),
553    })
554}
555
556async fn list_bridges(
557    State(state): State<Arc<AppState>>,
558) -> Json<Vec<crate::security::bridge::BridgeStatus>> {
559    Json(state.bridge_manager.get_active_bridges().await)
560}
561
// Session management endpoints
/// JSON body accepted by the create-session endpoint.
#[derive(Deserialize)]
struct CreateSessionRequest {
    /// ID to reuse or to create the session under; a UUID is generated when absent.
    session_id: Option<String>,
}
567
/// JSON body returned by the create-session endpoint.
#[derive(Serialize)]
struct SessionResponse {
    /// ID of the session that was found or created.
    session_id: String,
    /// Model name reported for the session (the configured default model).
    model: String,
}
573
574async fn create_session(
575    State(state): State<Arc<AppState>>,
576    Json(request): Json<CreateSessionRequest>,
577) -> Response {
578    match get_or_create_session(&state, request.session_id).await {
579        Ok(session_id) => Json(SessionResponse {
580            session_id,
581            model: state.config.agent.default_model.clone(),
582        })
583        .into_response(),
584        Err(e) => e.into_response(),
585    }
586}
587
/// One entry in the session listing.
#[derive(Serialize)]
struct SessionInfo {
    /// Key of the session in the in-memory cache.
    session_id: String,
    /// Whole seconds elapsed since the session was last accessed.
    idle_seconds: u64,
}
593
/// JSON body returned by the list-sessions endpoint.
#[derive(Serialize)]
struct ListSessionsResponse {
    /// All sessions currently cached, in map iteration order.
    sessions: Vec<SessionInfo>,
}
598
599async fn list_sessions(State(state): State<Arc<AppState>>) -> Json<ListSessionsResponse> {
600    let sessions = state.sessions.lock().await;
601
602    let session_list: Vec<SessionInfo> = sessions
603        .iter()
604        .map(|(id, entry)| SessionInfo {
605            session_id: id.clone(),
606            idle_seconds: entry.last_accessed.elapsed().as_secs(),
607        })
608        .collect();
609
610    Json(ListSessionsResponse {
611        sessions: session_list,
612    })
613}
614
615// Delete a session
616async fn delete_session(
617    State(state): State<Arc<AppState>>,
618    Path(session_id): Path<String>,
619) -> Response {
620    let mut sessions = state.sessions.lock().await;
621
622    if sessions.remove(&session_id).is_some() {
623        info!("Deleted session: {}", session_id);
624        Json(json!({"deleted": true, "session_id": session_id})).into_response()
625    } else {
626        AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response()
627    }
628}
629
// Get session status
/// JSON body returned by the session-status endpoint. Apart from
/// `session_id`, `model` and `idle_seconds`, every field is copied from the
/// agent's `session_status()` result.
#[derive(Serialize)]
struct SessionStatusResponse {
    /// ID the status was requested for.
    session_id: String,
    /// Model the session's agent currently reports.
    model: String,
    message_count: usize,
    token_count: usize,
    /// Whole seconds elapsed since the session was last accessed.
    idle_seconds: u64,
    api_input_tokens: u64,
    api_output_tokens: u64,
    search_queries: u64,
    search_cached_hits: u64,
    search_cost_usd: f64,
}
644
645async fn get_session_status(
646    State(state): State<Arc<AppState>>,
647    Path(session_id): Path<String>,
648) -> Response {
649    let sessions = state.sessions.lock().await;
650
651    match sessions.get(&session_id) {
652        Some(entry) => {
653            let status = entry.agent.session_status();
654            Json(SessionStatusResponse {
655                session_id,
656                model: entry.agent.model().to_string(),
657                message_count: status.message_count,
658                token_count: status.token_count,
659                idle_seconds: entry.last_accessed.elapsed().as_secs(),
660                api_input_tokens: status.api_input_tokens,
661                api_output_tokens: status.api_output_tokens,
662                search_queries: status.search_queries,
663                search_cached_hits: status.search_cached_hits,
664                search_cost_usd: status.search_cost_usd,
665            })
666            .into_response()
667        }
668        None => AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response(),
669    }
670}
671
// Get session messages - returns message history for an active session
/// One message of an active session as exposed over the API.
#[derive(Serialize)]
struct ActiveSessionMessage {
    /// One of `user`, `assistant`, `system` or `toolResult`.
    role: String,
    /// Message text; `None` when the stored content is empty.
    content: Option<String>,
    /// Tool calls as JSON objects with `id`, `name` and `arguments` keys.
    tool_calls: Option<Vec<serde_json::Value>>,
    /// Tool call ID copied from the stored message, if any.
    tool_call_id: Option<String>,
    /// Timestamp copied from the stored session message.
    timestamp: u64,
}
681
/// JSON body returned by the session-messages endpoint.
#[derive(Serialize)]
struct SessionMessagesResponse {
    /// ID the messages were requested for.
    session_id: String,
    /// Messages in the order the agent returns them.
    messages: Vec<ActiveSessionMessage>,
}
687
688async fn get_session_messages(
689    State(state): State<Arc<AppState>>,
690    Path(session_id): Path<String>,
691) -> Response {
692    let mut sessions = state.sessions.lock().await;
693
694    match sessions.get_mut(&session_id) {
695        Some(entry) => {
696            entry.last_accessed = Instant::now();
697
698            let messages: Vec<ActiveSessionMessage> = entry
699                .agent
700                .raw_session_messages()
701                .iter()
702                .map(|sm| {
703                    let role = match sm.message.role {
704                        localgpt_core::agent::Role::User => "user",
705                        localgpt_core::agent::Role::Assistant => "assistant",
706                        localgpt_core::agent::Role::System => "system",
707                        localgpt_core::agent::Role::Tool => "toolResult",
708                    };
709
710                    // Convert tool calls to JSON
711                    let tool_calls = sm.message.tool_calls.as_ref().map(|tcs| {
712                        tcs.iter()
713                            .map(|tc| {
714                                json!({
715                                    "id": tc.id,
716                                    "name": tc.name,
717                                    "arguments": tc.arguments
718                                })
719                            })
720                            .collect()
721                    });
722
723                    ActiveSessionMessage {
724                        role: role.to_string(),
725                        content: if sm.message.content.is_empty() {
726                            None
727                        } else {
728                            Some(sm.message.content.clone())
729                        },
730                        tool_calls,
731                        tool_call_id: sm.message.tool_call_id.clone(),
732                        timestamp: sm.timestamp,
733                    }
734                })
735                .collect();
736
737            Json(SessionMessagesResponse {
738                session_id,
739                messages,
740            })
741            .into_response()
742        }
743        None => AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response(),
744    }
745}
746
747// Compact session history
748async fn compact_session(
749    State(state): State<Arc<AppState>>,
750    Path(session_id): Path<String>,
751) -> Response {
752    let mut sessions = state.sessions.lock().await;
753
754    match sessions.get_mut(&session_id) {
755        Some(entry) => {
756            entry.last_accessed = Instant::now();
757
758            match entry.agent.compact_session().await {
759                Ok((before, after)) => Json(json!({
760                    "session_id": session_id,
761                    "token_count_before": before,
762                    "token_count_after": after,
763                }))
764                .into_response(),
765                Err(e) => {
766                    AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
767                }
768            }
769        }
770        None => AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response(),
771    }
772}
773
774// Clear session history
775async fn clear_session(
776    State(state): State<Arc<AppState>>,
777    Path(session_id): Path<String>,
778) -> Response {
779    let mut sessions = state.sessions.lock().await;
780
781    match sessions.get_mut(&session_id) {
782        Some(entry) => {
783            entry.last_accessed = Instant::now();
784            entry.agent.clear_session();
785            Json(json!({"session_id": session_id, "cleared": true})).into_response()
786        }
787        None => AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response(),
788    }
789}
790
// Set session model
/// JSON body accepted by the set-model endpoint.
#[derive(Deserialize)]
struct SetModelRequest {
    /// Name of the model to switch the session to.
    model: String,
}
796
797async fn set_session_model(
798    State(state): State<Arc<AppState>>,
799    Path(session_id): Path<String>,
800    Json(request): Json<SetModelRequest>,
801) -> Response {
802    let mut sessions = state.sessions.lock().await;
803
804    match sessions.get_mut(&session_id) {
805        Some(entry) => {
806            entry.last_accessed = Instant::now();
807
808            match entry.agent.set_model(&request.model) {
809                Ok(()) => Json(json!({
810                    "session_id": session_id,
811                    "model": request.model,
812                }))
813                .into_response(),
814                Err(e) => AppError(StatusCode::BAD_REQUEST, e.to_string()).into_response(),
815            }
816        }
817        None => AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response(),
818    }
819}
820
// Chat endpoint
/// JSON body accepted by the chat endpoints.
#[derive(Deserialize)]
struct ChatRequest {
    /// User message to send to the agent.
    message: String,
    /// Session to continue; a new session is created when absent or unknown.
    session_id: Option<String>,
    /// Optional model to use for this request (switches session model)
    model: Option<String>,
}
829
/// JSON body returned by the non-streaming chat endpoint.
#[derive(Serialize)]
struct ChatResponse {
    /// The agent's reply.
    response: String,
    /// Session the exchange belongs to.
    session_id: String,
    /// Model the session's agent reports after the exchange.
    model: String,
}
836
837async fn chat(State(state): State<Arc<AppState>>, Json(request): Json<ChatRequest>) -> Response {
838    // Get or create session
839    let session_id = match get_or_create_session(&state, request.session_id).await {
840        Ok(id) => id,
841        Err(e) => return e.into_response(),
842    };
843
844    // Acquire in-process turn gate (waits for other turns to finish)
845    let _gate_permit = state.turn_gate.acquire().await;
846
847    // Acquire cross-process workspace lock (blocking, so use spawn_blocking)
848    let ws_lock_path = state.workspace_lock.clone();
849    let ws_guard = match tokio::task::spawn_blocking(move || ws_lock_path.acquire()).await {
850        Ok(Ok(guard)) => guard,
851        Ok(Err(e)) => {
852            return AppError(
853                StatusCode::INTERNAL_SERVER_ERROR,
854                format!("Failed to acquire workspace lock: {}", e),
855            )
856            .into_response();
857        }
858        Err(e) => {
859            return AppError(
860                StatusCode::INTERNAL_SERVER_ERROR,
861                format!("Lock task error: {}", e),
862            )
863            .into_response();
864        }
865    };
866
867    // Get agent from session
868    let mut sessions = state.sessions.lock().await;
869    let entry = match sessions.get_mut(&session_id) {
870        Some(e) => e,
871        None => {
872            return AppError(StatusCode::NOT_FOUND, "Session not found".to_string())
873                .into_response();
874        }
875    };
876
877    entry.last_accessed = Instant::now();
878
879    // Switch model if requested
880    if let Some(ref model) = request.model
881        && let Err(e) = entry.agent.set_model(model)
882    {
883        return AppError(StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)).into_response();
884    }
885
886    let result = entry.agent.chat(&request.message).await;
887
888    // Release workspace lock explicitly before returning
889    drop(ws_guard);
890
891    match result {
892        Ok(response) => {
893            entry.dirty = true;
894            Json(ChatResponse {
895                response,
896                session_id,
897                model: entry.agent.model().to_string(),
898            })
899            .into_response()
900        }
901        Err(e) => AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
902    }
903}
904
905// Streaming chat endpoint (SSE) with tool support
906async fn chat_stream(
907    State(state): State<Arc<AppState>>,
908    Json(request): Json<ChatRequest>,
909) -> Response {
910    // Get or create session first (outside the stream)
911    let session_id = match get_or_create_session(&state, request.session_id).await {
912        Ok(id) => id,
913        Err(e) => return e.into_response(),
914    };
915
916    let state_clone = state.clone();
917    let message = request.message.clone();
918
919    let stream = async_stream::stream! {
920        // Send session_id first
921        yield Ok::<Event, Infallible>(Event::default().data(json!({"type": "session", "session_id": session_id}).to_string()));
922
923        // Acquire in-process turn gate
924        let _gate_permit = state_clone.turn_gate.acquire().await;
925
926        // Acquire cross-process workspace lock
927        let ws_lock = state_clone.workspace_lock.clone();
928        let _ws_guard = match tokio::task::spawn_blocking(move || ws_lock.acquire()).await {
929            Ok(Ok(guard)) => Some(guard),
930            Ok(Err(e)) => {
931                yield Ok(Event::default().data(json!({"error": format!("Workspace lock error: {}", e)}).to_string()));
932                return;
933            }
934            Err(e) => {
935                yield Ok(Event::default().data(json!({"error": format!("Lock task error: {}", e)}).to_string()));
936                return;
937            }
938        };
939
940        let mut sessions = state_clone.sessions.lock().await;
941        let entry = match sessions.get_mut(&session_id) {
942            Some(e) => e,
943            None => {
944                yield Ok(Event::default().data(json!({"error": "Session not found"}).to_string()));
945                return;
946            }
947        };
948
949        entry.last_accessed = Instant::now();
950        entry.dirty = true;
951
952        // Use streaming with tools
953        match entry.agent.chat_stream_with_tools(&message, Vec::new()).await {
954            Ok(event_stream) => {
955                use futures::StreamExt;
956
957                // Pin the stream to iterate over it
958                let mut pinned_stream = std::pin::pin!(event_stream);
959
960                while let Some(event) = pinned_stream.next().await {
961                    match event {
962                        Ok(StreamEvent::Content(content)) => {
963                            let data = json!({"type": "content", "delta": content});
964                            yield Ok(Event::default().data(data.to_string()));
965                        }
966                        Ok(StreamEvent::ToolCallStart { name, id, arguments }) => {
967                            let detail = extract_tool_detail(&name, &arguments);
968                            let data = json!({"type": "tool_start", "name": name, "id": id, "detail": detail});
969                            yield Ok(Event::default().data(data.to_string()));
970                        }
971                        Ok(StreamEvent::ToolCallEnd { name, id, output, warnings }) => {
972                            let data = json!({
973                                "type": "tool_end",
974                                "name": name,
975                                "id": id,
976                                "output": output.chars().take(500).collect::<String>(),
977                                "warnings": warnings
978                            });
979                            yield Ok(Event::default().data(data.to_string()));
980                        }
981                        Ok(StreamEvent::Done) => {
982                            let data = json!({"type": "done"});
983                            yield Ok(Event::default().data(data.to_string()));
984                        }
985                        Err(e) => {
986                            yield Ok(Event::default().data(json!({"error": e.to_string()}).to_string()));
987                            break;
988                        }
989                    }
990                }
991            }
992            Err(e) => {
993                yield Ok(Event::default().data(json!({"error": e.to_string()}).to_string()));
994            }
995        }
996
997        yield Ok(Event::default().data("[DONE]"));
998    };
999
1000    Sse::new(stream).into_response()
1001}
1002
1003// Memory search endpoint
1004#[derive(Deserialize)]
1005struct SearchQuery {
1006    q: String,
1007    limit: Option<usize>,
1008}
1009
1010#[derive(Serialize)]
1011struct SearchResult {
1012    file: String,
1013    line_start: i32,
1014    line_end: i32,
1015    content: String,
1016    score: f64,
1017}
1018
1019#[derive(Serialize)]
1020struct SearchResponse {
1021    results: Vec<SearchResult>,
1022    query: String,
1023}
1024
1025async fn memory_search(
1026    State(state): State<Arc<AppState>>,
1027    Query(query): Query<SearchQuery>,
1028) -> Response {
1029    // Reject excessively long queries to prevent DoS
1030    if query.q.len() > 1000 {
1031        return AppError(StatusCode::BAD_REQUEST, "Query too long".to_string()).into_response();
1032    }
1033    match memory_search_inner(&state.memory, &query.q, query.limit) {
1034        Ok(response) => Json(response).into_response(),
1035        Err(e) => AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
1036    }
1037}
1038
1039fn memory_search_inner(
1040    memory: &MemoryManager,
1041    query: &str,
1042    limit: Option<usize>,
1043) -> Result<SearchResponse, anyhow::Error> {
1044    let limit = limit.unwrap_or(10).min(100);
1045    let results = memory.search(query, limit)?;
1046
1047    let results: Vec<SearchResult> = results
1048        .into_iter()
1049        .map(|r| SearchResult {
1050            file: r.file,
1051            line_start: r.line_start,
1052            line_end: r.line_end,
1053            content: r.content,
1054            score: r.score,
1055        })
1056        .collect();
1057
1058    Ok(SearchResponse {
1059        results,
1060        query: query.to_string(),
1061    })
1062}
1063
1064// Memory stats endpoint
1065#[derive(Serialize)]
1066struct StatsResponse {
1067    workspace: String,
1068    total_files: usize,
1069    total_chunks: usize,
1070    index_size_kb: u64,
1071}
1072
1073async fn memory_stats(State(state): State<Arc<AppState>>) -> Response {
1074    match memory_stats_inner(&state.memory) {
1075        Ok(response) => Json(response).into_response(),
1076        Err(e) => AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
1077    }
1078}
1079
1080fn memory_stats_inner(memory: &MemoryManager) -> Result<StatsResponse, anyhow::Error> {
1081    let stats = memory.stats()?;
1082
1083    Ok(StatsResponse {
1084        workspace: stats.workspace,
1085        total_files: stats.total_files,
1086        total_chunks: stats.total_chunks,
1087        index_size_kb: stats.index_size_kb,
1088    })
1089}
1090
1091// Memory reindex endpoint
1092#[derive(Deserialize)]
1093struct ReindexRequest {
1094    #[serde(default)]
1095    force: bool,
1096}
1097
1098#[derive(Serialize)]
1099struct ReindexResponse {
1100    files_processed: usize,
1101    files_updated: usize,
1102    chunks_indexed: usize,
1103    duration_ms: u128,
1104}
1105
1106async fn memory_reindex(
1107    State(state): State<Arc<AppState>>,
1108    Json(request): Json<ReindexRequest>,
1109) -> Response {
1110    // Run reindex in blocking task since it uses sqlite
1111    let memory = state.memory.clone();
1112    let force = request.force;
1113
1114    match tokio::task::spawn_blocking(move || memory_reindex_inner(&memory, force)).await {
1115        Ok(Ok(response)) => Json(response).into_response(),
1116        Ok(Err(e)) => AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
1117        Err(e) => AppError(
1118            StatusCode::INTERNAL_SERVER_ERROR,
1119            format!("Task error: {}", e),
1120        )
1121        .into_response(),
1122    }
1123}
1124
1125fn memory_reindex_inner(
1126    memory: &MemoryManager,
1127    force: bool,
1128) -> Result<ReindexResponse, anyhow::Error> {
1129    let stats = memory.reindex(force)?;
1130
1131    Ok(ReindexResponse {
1132        files_processed: stats.files_processed,
1133        files_updated: stats.files_updated,
1134        chunks_indexed: stats.chunks_indexed,
1135        duration_ms: stats.duration.as_millis(),
1136    })
1137}
1138
1139// Config endpoint - show current configuration (safe subset)
1140#[derive(Serialize)]
1141struct ConfigResponse {
1142    agent: AgentConfigInfo,
1143    server: ServerConfigInfo,
1144    memory: MemoryConfigInfo,
1145    heartbeat: HeartbeatConfigInfo,
1146}
1147
1148#[derive(Serialize)]
1149struct AgentConfigInfo {
1150    default_model: String,
1151    context_window: usize,
1152    reserve_tokens: usize,
1153}
1154
1155#[derive(Serialize)]
1156struct ServerConfigInfo {
1157    port: u16,
1158    bind: String,
1159}
1160
1161#[derive(Serialize)]
1162struct MemoryConfigInfo {
1163    workspace: String,
1164    embedding_model: String,
1165    chunk_size: usize,
1166    chunk_overlap: usize,
1167}
1168
1169#[derive(Serialize)]
1170struct HeartbeatConfigInfo {
1171    enabled: bool,
1172    interval: String,
1173}
1174
1175async fn get_config(State(state): State<Arc<AppState>>) -> Json<ConfigResponse> {
1176    Json(ConfigResponse {
1177        agent: AgentConfigInfo {
1178            default_model: state.config.agent.default_model.clone(),
1179            context_window: state.config.agent.context_window,
1180            reserve_tokens: state.config.agent.reserve_tokens,
1181        },
1182        server: ServerConfigInfo {
1183            port: state.config.server.port,
1184            bind: state.config.server.bind.clone(),
1185        },
1186        memory: MemoryConfigInfo {
1187            workspace: state.config.memory.workspace.clone(),
1188            embedding_model: state.config.memory.embedding_model.clone(),
1189            chunk_size: state.config.memory.chunk_size,
1190            chunk_overlap: state.config.memory.chunk_overlap,
1191        },
1192        heartbeat: HeartbeatConfigInfo {
1193            enabled: state.config.heartbeat.enabled,
1194            interval: state.config.heartbeat.interval.clone(),
1195        },
1196    })
1197}
1198
1199// Heartbeat status endpoint
1200#[derive(Serialize)]
1201struct HeartbeatStatusResponse {
1202    enabled: bool,
1203    interval: String,
1204    last_event: Option<HeartbeatEventInfo>,
1205}
1206
1207#[derive(Serialize)]
1208struct HeartbeatEventInfo {
1209    ts: u64,
1210    status: String,
1211    duration_ms: u64,
1212    preview: Option<String>,
1213    reason: Option<String>,
1214    age_seconds: u64,
1215}
1216
1217async fn heartbeat_status(State(state): State<Arc<AppState>>) -> Json<HeartbeatStatusResponse> {
1218    let last_event = get_last_heartbeat_event().map(|event| {
1219        let now_ms = std::time::SystemTime::now()
1220            .duration_since(std::time::UNIX_EPOCH)
1221            .map(|d| d.as_millis() as u64)
1222            .unwrap_or(0);
1223        let age_seconds = (now_ms.saturating_sub(event.ts)) / 1000;
1224
1225        let status = match event.status {
1226            HeartbeatStatus::Sent => "sent",
1227            HeartbeatStatus::Ok => "ok",
1228            HeartbeatStatus::Skipped => "skipped",
1229            HeartbeatStatus::SkippedMayTry => "skipped",
1230            HeartbeatStatus::Failed => "failed",
1231            HeartbeatStatus::TimedOut => "timed_out",
1232        };
1233
1234        HeartbeatEventInfo {
1235            ts: event.ts,
1236            status: status.to_string(),
1237            duration_ms: event.duration_ms,
1238            preview: event.preview,
1239            reason: event.reason,
1240            age_seconds,
1241        }
1242    });
1243
1244    Json(HeartbeatStatusResponse {
1245        enabled: state.config.heartbeat.enabled,
1246        interval: state.config.heartbeat.interval.clone(),
1247        last_event,
1248    })
1249}
1250
1251// Saved sessions endpoint - list sessions from file store
1252#[derive(Serialize)]
1253struct SavedSessionInfo {
1254    id: String,
1255    message_count: usize,
1256    created_at: String,
1257}
1258
1259#[derive(Serialize)]
1260struct SavedSessionsResponse {
1261    sessions: Vec<SavedSessionInfo>,
1262}
1263
1264async fn list_saved_sessions(State(_state): State<Arc<AppState>>) -> Response {
1265    use localgpt_core::agent::list_sessions_for_agent;
1266
1267    match list_sessions_for_agent(HTTP_AGENT_ID) {
1268        Ok(sessions) => {
1269            let session_list: Vec<SavedSessionInfo> = sessions
1270                .into_iter()
1271                .map(|s| SavedSessionInfo {
1272                    id: s.id,
1273                    message_count: s.message_count,
1274                    created_at: s.created_at.format("%Y-%m-%dT%H:%M:%S").to_string(),
1275                })
1276                .collect();
1277
1278            Json(SavedSessionsResponse {
1279                sessions: session_list,
1280            })
1281            .into_response()
1282        }
1283        Err(e) => AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
1284    }
1285}
1286
1287// Get saved session detail - read and parse JSONL session file
1288#[derive(Serialize)]
1289struct SavedSessionMessage {
1290    role: String,
1291    content: Option<String>,
1292    tool_calls: Option<Vec<serde_json::Value>>,
1293    tool_call_id: Option<String>,
1294    timestamp: Option<u64>,
1295}
1296
1297#[derive(Serialize)]
1298struct SavedSessionDetail {
1299    session_id: String,
1300    created_at: String,
1301    messages: Vec<SavedSessionMessage>,
1302}
1303
1304async fn get_saved_session(Path(session_id): Path<String>) -> Response {
1305    use localgpt_core::agent::get_sessions_dir_for_agent;
1306    use std::fs::File;
1307    use std::io::{BufRead, BufReader};
1308
1309    // Validate session_id uses only safe characters (alphanumeric, hyphens, underscores)
1310    if !session_id
1311        .chars()
1312        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
1313    {
1314        return AppError(StatusCode::BAD_REQUEST, "Invalid session ID".to_string()).into_response();
1315    }
1316
1317    let sessions_dir = match get_sessions_dir_for_agent(HTTP_AGENT_ID) {
1318        Ok(dir) => dir,
1319        Err(e) => {
1320            return AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response();
1321        }
1322    };
1323
1324    let session_path = sessions_dir.join(format!("{}.jsonl", session_id));
1325
1326    if !session_path.exists() {
1327        return AppError(StatusCode::NOT_FOUND, "Session not found".to_string()).into_response();
1328    }
1329
1330    let file = match File::open(&session_path) {
1331        Ok(f) => f,
1332        Err(e) => {
1333            return AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response();
1334        }
1335    };
1336
1337    let reader = BufReader::new(file);
1338    let mut messages = Vec::new();
1339    let mut created_at = String::new();
1340
1341    for (i, line) in reader.lines().enumerate() {
1342        let line = match line {
1343            Ok(l) => l,
1344            Err(_) => continue,
1345        };
1346
1347        let parsed: serde_json::Value = match serde_json::from_str(&line) {
1348            Ok(v) => v,
1349            Err(_) => continue,
1350        };
1351
1352        // First line is session header
1353        if i == 0 && parsed["type"].as_str() == Some("session") {
1354            created_at = parsed["timestamp"].as_str().unwrap_or("").to_string();
1355            continue;
1356        }
1357
1358        // Parse message entries
1359        if parsed["type"].as_str() == Some("message")
1360            && let Some(msg) = parsed.get("message")
1361        {
1362            let role = msg["role"].as_str().unwrap_or("unknown").to_string();
1363
1364            // Extract text content
1365            let content = if let Some(content_arr) = msg["content"].as_array() {
1366                content_arr
1367                    .iter()
1368                    .filter_map(|c| {
1369                        if c["type"].as_str() == Some("text") {
1370                            c["text"].as_str().map(String::from)
1371                        } else {
1372                            None
1373                        }
1374                    })
1375                    .collect::<Vec<_>>()
1376                    .join("\n")
1377            } else if let Some(text) = msg["content"].as_str() {
1378                text.to_string()
1379            } else {
1380                String::new()
1381            };
1382
1383            // Extract tool calls
1384            let tool_calls = msg["toolCalls"].as_array().cloned();
1385
1386            // Extract tool result ID
1387            let tool_call_id = msg["toolCallId"].as_str().map(String::from);
1388
1389            let timestamp = msg["timestamp"].as_u64();
1390
1391            messages.push(SavedSessionMessage {
1392                role,
1393                content: if content.is_empty() {
1394                    None
1395                } else {
1396                    Some(content)
1397                },
1398                tool_calls,
1399                tool_call_id,
1400                timestamp,
1401            });
1402        }
1403    }
1404
1405    Json(SavedSessionDetail {
1406        session_id,
1407        created_at,
1408        messages,
1409    })
1410    .into_response()
1411}
1412
1413// Daemon logs endpoint - read log file
1414#[derive(Deserialize)]
1415struct LogsQuery {
1416    lines: Option<usize>,
1417}
1418
1419#[derive(Serialize)]
1420struct DaemonLogsResponse {
1421    lines: Vec<String>,
1422    total_lines: usize,
1423    file_size_bytes: u64,
1424}
1425
1426async fn get_daemon_logs(Query(query): Query<LogsQuery>) -> Response {
1427    use localgpt_core::agent::get_state_dir;
1428    use std::fs::File;
1429    use std::io::{BufRead, BufReader};
1430
1431    let lines_requested = query.lines.unwrap_or(200).min(1000);
1432
1433    let state_dir = match get_state_dir() {
1434        Ok(dir) => dir,
1435        Err(e) => {
1436            return AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response();
1437        }
1438    };
1439
1440    // Use date-based log file (matches daemon.rs)
1441    let date = chrono::Local::now().format("%Y-%m-%d");
1442    let log_path = state_dir
1443        .join("logs")
1444        .join(format!("localgpt-{}.log", date));
1445
1446    if !log_path.exists() {
1447        return Json(DaemonLogsResponse {
1448            lines: vec![],
1449            total_lines: 0,
1450            file_size_bytes: 0,
1451        })
1452        .into_response();
1453    }
1454
1455    let metadata = match std::fs::metadata(&log_path) {
1456        Ok(m) => m,
1457        Err(e) => {
1458            return AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response();
1459        }
1460    };
1461
1462    let file = match File::open(&log_path) {
1463        Ok(f) => f,
1464        Err(e) => {
1465            return AppError(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response();
1466        }
1467    };
1468
1469    let reader = BufReader::new(file);
1470    let all_lines: Vec<String> = reader.lines().map_while(Result::ok).collect();
1471    let total_lines = all_lines.len();
1472
1473    // Get last N lines
1474    let lines: Vec<String> = if total_lines > lines_requested {
1475        all_lines[(total_lines - lines_requested)..].to_vec()
1476    } else {
1477        all_lines
1478    };
1479
1480    Json(DaemonLogsResponse {
1481        lines,
1482        total_lines,
1483        file_size_bytes: metadata.len(),
1484    })
1485    .into_response()
1486}
1487
1488// WebSocket handler
1489async fn websocket_handler(
1490    ws: WebSocketUpgrade,
1491    State(state): State<Arc<AppState>>,
1492) -> impl IntoResponse {
1493    // Limit WebSocket message size to 1MB to prevent DoS
1494    ws.max_message_size(1024 * 1024)
1495        .on_upgrade(|socket| handle_websocket(socket, state))
1496}
1497
1498/// WebSocket message types
1499#[derive(Deserialize)]
1500#[serde(tag = "type")]
1501enum WsIncoming {
1502    /// Start or resume a session
1503    #[serde(rename = "session")]
1504    Session { session_id: Option<String> },
1505    /// Chat message (uses tool loop, returns complete response)
1506    /// For streaming, use the SSE endpoint at /api/chat/stream
1507    #[serde(rename = "chat")]
1508    Chat { message: String },
1509    /// Ping for keepalive
1510    #[serde(rename = "ping")]
1511    Ping,
1512}
1513
1514#[derive(Serialize)]
1515#[serde(tag = "type")]
1516#[allow(dead_code)] // ToolStart/ToolEnd reserved for streaming with tools
1517enum WsOutgoing {
1518    /// Connection established
1519    #[serde(rename = "connected")]
1520    Connected { session_id: String },
1521    /// Text content chunk
1522    #[serde(rename = "content")]
1523    Content { delta: String },
1524    /// Tool call started
1525    #[serde(rename = "tool_start")]
1526    ToolStart { name: String, id: String },
1527    /// Tool call completed
1528    #[serde(rename = "tool_end")]
1529    ToolEnd {
1530        name: String,
1531        id: String,
1532        output: String,
1533    },
1534    /// Message complete
1535    #[serde(rename = "done")]
1536    Done,
1537    /// Pong response
1538    #[serde(rename = "pong")]
1539    Pong,
1540    /// Error
1541    #[serde(rename = "error")]
1542    Error { message: String },
1543}
1544
1545async fn handle_websocket(socket: WebSocket, state: Arc<AppState>) {
1546    let (mut sender, mut receiver) = socket.split();
1547
1548    debug!("WebSocket client connected");
1549
1550    // Track current session for this connection
1551    let mut current_session_id: Option<String> = None;
1552
1553    // Process incoming messages
1554    while let Some(msg) = receiver.next().await {
1555        match msg {
1556            Ok(WsMessage::Text(text)) => {
1557                // Parse incoming message
1558                match serde_json::from_str::<WsIncoming>(&text) {
1559                    Ok(WsIncoming::Session { session_id }) => {
1560                        // Create or resume session
1561                        match get_or_create_session(&state, session_id).await {
1562                            Ok(id) => {
1563                                current_session_id = Some(id.clone());
1564                                let connected = WsOutgoing::Connected { session_id: id };
1565                                if let Ok(json) = serde_json::to_string(&connected) {
1566                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1567                                }
1568                            }
1569                            Err(e) => {
1570                                let error = WsOutgoing::Error {
1571                                    message: format!("Failed to create session: {}", e.1),
1572                                };
1573                                if let Ok(json) = serde_json::to_string(&error) {
1574                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1575                                }
1576                            }
1577                        }
1578                    }
1579                    Ok(WsIncoming::Chat { message }) => {
1580                        // Ensure we have a session
1581                        let session_id = match &current_session_id {
1582                            Some(id) => id.clone(),
1583                            None => {
1584                                // Auto-create session if none exists
1585                                match get_or_create_session(&state, None).await {
1586                                    Ok(id) => {
1587                                        current_session_id = Some(id.clone());
1588                                        // Notify client of new session
1589                                        let connected = WsOutgoing::Connected {
1590                                            session_id: id.clone(),
1591                                        };
1592                                        if let Ok(json) = serde_json::to_string(&connected) {
1593                                            let _ = sender.send(WsMessage::Text(json.into())).await;
1594                                        }
1595                                        id
1596                                    }
1597                                    Err(e) => {
1598                                        let error = WsOutgoing::Error {
1599                                            message: format!("Failed to create session: {}", e.1),
1600                                        };
1601                                        if let Ok(json) = serde_json::to_string(&error) {
1602                                            let _ = sender.send(WsMessage::Text(json.into())).await;
1603                                        }
1604                                        continue;
1605                                    }
1606                                }
1607                            }
1608                        };
1609
1610                        debug!("WebSocket chat [{}]: {}", session_id, message);
1611
1612                        // Acquire in-process turn gate
1613                        let _gate_permit = state.turn_gate.acquire().await;
1614
1615                        // Acquire cross-process workspace lock
1616                        let ws_lock = state.workspace_lock.clone();
1617                        let _ws_guard =
1618                            match tokio::task::spawn_blocking(move || ws_lock.acquire()).await {
1619                                Ok(Ok(guard)) => guard,
1620                                Ok(Err(e)) => {
1621                                    let error = WsOutgoing::Error {
1622                                        message: format!("Workspace lock error: {}", e),
1623                                    };
1624                                    if let Ok(json) = serde_json::to_string(&error) {
1625                                        let _ = sender.send(WsMessage::Text(json.into())).await;
1626                                    }
1627                                    continue;
1628                                }
1629                                Err(e) => {
1630                                    let error = WsOutgoing::Error {
1631                                        message: format!("Lock task error: {}", e),
1632                                    };
1633                                    if let Ok(json) = serde_json::to_string(&error) {
1634                                        let _ = sender.send(WsMessage::Text(json.into())).await;
1635                                    }
1636                                    continue;
1637                                }
1638                            };
1639
1640                        // Process chat
1641                        let mut sessions = state.sessions.lock().await;
1642                        let entry = match sessions.get_mut(&session_id) {
1643                            Some(e) => e,
1644                            None => {
1645                                let error = WsOutgoing::Error {
1646                                    message: "Session not found".to_string(),
1647                                };
1648                                if let Ok(json) = serde_json::to_string(&error) {
1649                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1650                                }
1651                                current_session_id = None;
1652                                continue;
1653                            }
1654                        };
1655
1656                        entry.last_accessed = Instant::now();
1657
1658                        match entry.agent.chat(&message).await {
1659                            Ok(response) => {
1660                                // Send response as content
1661                                let content = WsOutgoing::Content { delta: response };
1662                                if let Ok(json) = serde_json::to_string(&content) {
1663                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1664                                }
1665
1666                                // Send done
1667                                let done = WsOutgoing::Done;
1668                                if let Ok(json) = serde_json::to_string(&done) {
1669                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1670                                }
1671                            }
1672                            Err(e) => {
1673                                let error = WsOutgoing::Error {
1674                                    message: e.to_string(),
1675                                };
1676                                if let Ok(json) = serde_json::to_string(&error) {
1677                                    let _ = sender.send(WsMessage::Text(json.into())).await;
1678                                }
1679                            }
1680                        }
1681                    }
1682                    Ok(WsIncoming::Ping) => {
1683                        let pong = WsOutgoing::Pong;
1684                        if let Ok(json) = serde_json::to_string(&pong) {
1685                            let _ = sender.send(WsMessage::Text(json.into())).await;
1686                        }
1687                    }
1688                    Err(e) => {
1689                        let error = WsOutgoing::Error {
1690                            message: format!("Invalid message format: {}", e),
1691                        };
1692                        if let Ok(json) = serde_json::to_string(&error) {
1693                            let _ = sender.send(WsMessage::Text(json.into())).await;
1694                        }
1695                    }
1696                }
1697            }
1698            Ok(WsMessage::Ping(data)) => {
1699                let _ = sender.send(WsMessage::Pong(data)).await;
1700            }
1701            Ok(WsMessage::Close(_)) => {
1702                debug!("WebSocket client disconnected");
1703                break;
1704            }
1705            Err(e) => {
1706                debug!("WebSocket error: {}", e);
1707                break;
1708            }
1709            _ => {}
1710        }
1711    }
1712
1713    debug!("WebSocket connection closed");
1714}