Skip to main content

codewhale_app_server/
lib.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, bail};
6use axum::extract::{Request, State};
7use axum::http::{HeaderValue, Method, StatusCode, header};
8use axum::middleware::{self, Next};
9use axum::response::{IntoResponse, Response};
10use axum::routing::{get, post};
11use axum::{Json, Router};
12use codewhale_agent::ModelRegistry;
13use codewhale_config::{CliRuntimeOverrides, ConfigStore};
14use codewhale_core::Runtime;
15use codewhale_hooks::{HookDispatcher, JsonlHookSink, StdoutHookSink, UnixSocketHookSink};
16use codewhale_mcp::McpManager;
17use codewhale_protocol::{
18    AppRequest, AppResponse, PromptRequest, PromptResponse, ThreadRequest, ThreadResponse,
19};
20use codewhale_state::StateStore;
21use codewhale_tools::{ToolCall, ToolRegistry};
22use serde::de::DeserializeOwned;
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
26use tokio::sync::{Mutex, RwLock};
27use tower_http::cors::CorsLayer;
28use uuid::Uuid;
29
/// Browser origins that the app-server CORS layer always allows.
///
/// `cors_layer` appends any valid, non-duplicate origins from
/// `AppServerOptions::cors_origins` after these.
const DEFAULT_CORS_ORIGINS: &[&str] = &[
    // `localhost`, bare and on several local ports
    "http://localhost",
    "http://localhost:1420",
    "http://localhost:3000",
    "http://localhost:5173",
    // loopback IP literal
    "http://127.0.0.1",
    "http://127.0.0.1:1420",
    // `tauri://` scheme origin
    "tauri://localhost",
];
39
/// Startup options for the HTTP app server launched by [`run`].
#[derive(Clone)]
pub struct AppServerOptions {
    /// Address the TCP listener binds to.
    pub listen: SocketAddr,
    /// Config file to load. Its parent directory also holds `state.db` and
    /// `events.jsonl` when set.
    pub config_path: Option<PathBuf>,
    /// Bearer token for protected routes. When `None` (and auth is not disabled),
    /// a random token is generated at startup.
    pub auth_token: Option<String>,
    /// Turns off bearer-token auth; only accepted on a loopback bind.
    pub insecure_no_auth: bool,
    /// Extra CORS origins allowed in addition to the built-in defaults.
    pub cors_origins: Vec<String>,
}
48
49impl std::fmt::Debug for AppServerOptions {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("AppServerOptions")
52            .field("listen", &self.listen)
53            .field("config_path", &self.config_path)
54            .field(
55                "auth_token",
56                &self.auth_token.as_ref().map(|_| "<redacted>"),
57            )
58            .field("insecure_no_auth", &self.insecure_no_auth)
59            .field("cors_origins", &self.cors_origins)
60            .finish()
61    }
62}
63
64#[derive(Clone)]
65struct AppState {
66    config_path: Option<PathBuf>,
67    config: Arc<RwLock<codewhale_config::ConfigToml>>,
68    runtime: Arc<Mutex<Runtime>>,
69    registry: ModelRegistry,
70    auth_token: Option<String>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74struct ToolCallRequest {
75    call: ToolCall,
76    #[serde(default)]
77    cwd: Option<PathBuf>,
78}
79
80#[derive(Debug, Deserialize)]
81struct JsonRpcRequest {
82    #[serde(default)]
83    jsonrpc: Option<String>,
84    #[serde(default)]
85    id: Option<Value>,
86    method: String,
87    #[serde(default)]
88    params: Value,
89}
90
91#[derive(Debug)]
92struct JsonRpcError {
93    code: i64,
94    message: String,
95    data: Option<Value>,
96}
97
98#[derive(Debug)]
99struct StdioDispatchResult {
100    result: Value,
101    should_exit: bool,
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq)]
105enum AppTransport {
106    Http,
107    Stdio,
108}
109
110#[derive(Debug, Deserialize)]
111struct ConfigGetParams {
112    key: String,
113}
114
115#[derive(Debug, Deserialize)]
116struct ConfigSetParams {
117    key: String,
118    value: String,
119}
120
121#[derive(Debug, Deserialize)]
122struct ThreadIdParams {
123    thread_id: String,
124}
125
126#[derive(Debug, Deserialize)]
127struct ThreadMessageParams {
128    thread_id: String,
129    input: String,
130}
131
132pub async fn run(options: AppServerOptions) -> Result<()> {
133    let auth_token = resolve_auth_token(&options)?;
134    let state = build_state(options.config_path.clone(), auth_token)?;
135    let app = app_router(state, &options.cors_origins);
136
137    let listener = tokio::net::TcpListener::bind(options.listen).await?;
138    axum::serve(listener, app).await?;
139    Ok(())
140}
141
142fn app_router(state: AppState, cors_origins: &[String]) -> Router {
143    let protected_routes = Router::new()
144        .route("/thread", post(thread_handler))
145        .route("/app", post(app_handler))
146        .route("/prompt", post(prompt_handler))
147        .route("/tool", post(tool_handler))
148        .route("/jobs", get(jobs_handler))
149        .route("/mcp/startup", post(mcp_startup_handler))
150        .route_layer(middleware::from_fn_with_state(
151            state.clone(),
152            require_app_server_token,
153        ));
154
155    Router::new()
156        .route("/healthz", get(healthz))
157        .merge(protected_routes)
158        .layer(cors_layer(cors_origins))
159        .with_state(state)
160}
161
162pub async fn run_stdio(config_path: Option<PathBuf>) -> Result<()> {
163    let state = build_state(config_path, None)?;
164    let stdin = tokio::io::stdin();
165    let stdout = tokio::io::stdout();
166    let mut reader = BufReader::new(stdin).lines();
167    let mut writer = tokio::io::BufWriter::new(stdout);
168    while let Some(line) = reader.next_line().await? {
169        if line.trim().is_empty() {
170            continue;
171        }
172
173        let request: JsonRpcRequest = match serde_json::from_str(&line) {
174            Ok(value) => value,
175            Err(err) => {
176                let response = jsonrpc_error(
177                    None,
178                    JsonRpcError::parse_error(format!("invalid json: {err}")),
179                );
180                writer.write_all(response.to_string().as_bytes()).await?;
181                writer.write_all(b"\n").await?;
182                writer.flush().await?;
183                continue;
184            }
185        };
186
187        if request
188            .jsonrpc
189            .as_deref()
190            .is_some_and(|version| version != "2.0")
191        {
192            let response = jsonrpc_error(
193                request.id,
194                JsonRpcError::invalid_request("jsonrpc version must be 2.0"),
195            );
196            writer.write_all(response.to_string().as_bytes()).await?;
197            writer.write_all(b"\n").await?;
198            writer.flush().await?;
199            continue;
200        }
201
202        let response = match dispatch_stdio_request(&state, &request.method, request.params).await {
203            Ok(dispatch) => {
204                let encoded = jsonrpc_result(request.id, dispatch.result);
205                writer.write_all(encoded.to_string().as_bytes()).await?;
206                writer.write_all(b"\n").await?;
207                writer.flush().await?;
208                if dispatch.should_exit {
209                    break;
210                }
211                continue;
212            }
213            Err(err) => jsonrpc_error(request.id, err),
214        };
215
216        writer.write_all(response.to_string().as_bytes()).await?;
217        writer.write_all(b"\n").await?;
218        writer.flush().await?;
219    }
220
221    Ok(())
222}
223
224async fn healthz() -> Json<Value> {
225    Json(json!({
226        "status": "ok",
227        "protocol": "v2",
228        "service": "deepseek-app-server"
229    }))
230}
231
232async fn thread_handler(
233    State(state): State<AppState>,
234    Json(req): Json<ThreadRequest>,
235) -> Json<ThreadResponse> {
236    let mut runtime = state.runtime.lock().await;
237    match runtime.handle_thread(req).await {
238        Ok(res) => Json(res),
239        Err(err) => Json(ThreadResponse {
240            thread_id: "error".to_string(),
241            status: format!("error:{err}"),
242            thread: None,
243            threads: Vec::new(),
244            model: None,
245            model_provider: None,
246            cwd: None,
247            approval_policy: None,
248            sandbox: None,
249            events: Vec::new(),
250            data: json!({}),
251        }),
252    }
253}
254
255async fn prompt_handler(
256    State(state): State<AppState>,
257    Json(req): Json<PromptRequest>,
258) -> Json<PromptResponse> {
259    let mut runtime = state.runtime.lock().await;
260    let overrides = CliRuntimeOverrides::default();
261    match runtime.handle_prompt(req, &overrides).await {
262        Ok(res) => Json(res),
263        Err(err) => Json(PromptResponse {
264            output: err.to_string(),
265            model: "unknown".to_string(),
266            events: Vec::new(),
267        }),
268    }
269}
270
271async fn tool_handler(
272    State(state): State<AppState>,
273    Json(req): Json<ToolCallRequest>,
274) -> Json<Value> {
275    let runtime = state.runtime.lock().await;
276    let cwd = req
277        .cwd
278        .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
279    // Resolve approval policy from config instead of hardcoding.
280    let approval_mode = {
281        let cfg = state.config.read().await;
282        cfg.approval_policy
283            .as_deref()
284            .and_then(|p| match p.trim().to_ascii_lowercase().as_str() {
285                "auto" | "yolo" => Some(codewhale_execpolicy::AskForApproval::UnlessTrusted),
286                "never" | "deny" => Some(codewhale_execpolicy::AskForApproval::Never),
287                _ => None,
288            })
289            .unwrap_or(codewhale_execpolicy::AskForApproval::OnRequest)
290    };
291    match runtime.invoke_tool(req.call, approval_mode, &cwd).await {
292        Ok(value) => Json(value),
293        Err(err) => Json(json!({ "ok": false, "error": err.to_string() })),
294    }
295}
296
297async fn jobs_handler(State(state): State<AppState>) -> Json<AppResponse> {
298    let runtime = state.runtime.lock().await;
299    Json(runtime.app_status())
300}
301
302async fn mcp_startup_handler(State(state): State<AppState>) -> Json<Value> {
303    let runtime = state.runtime.lock().await;
304    let summary = runtime.mcp_startup().await;
305    Json(json!({
306        "ok": true,
307        "summary": summary
308    }))
309}
310
311async fn app_handler(
312    State(state): State<AppState>,
313    Json(req): Json<AppRequest>,
314) -> Json<AppResponse> {
315    Json(process_app_request(&state, req, AppTransport::Http).await)
316}
317
318fn build_state(config_path: Option<PathBuf>, auth_token: Option<String>) -> Result<AppState> {
319    let store = ConfigStore::load(config_path.clone())?;
320    let config = store.config.clone();
321    let exec_policy = store.exec_policy_engine();
322    let registry = ModelRegistry::default();
323
324    let state_db_path = config_path
325        .as_ref()
326        .and_then(|p| p.parent().map(|parent| parent.join("state.db")));
327    let state_store = StateStore::open(state_db_path)?;
328
329    let mut hooks = HookDispatcher::default();
330    hooks.add_sink(Arc::new(StdoutHookSink));
331    let hook_log_path = config_path
332        .as_ref()
333        .and_then(|p| p.parent().map(|parent| parent.join("events.jsonl")))
334        .unwrap_or_else(|| PathBuf::from(".deepseek/events.jsonl"));
335    hooks.add_sink(Arc::new(JsonlHookSink::new(hook_log_path)));
336
337    if let Some(socket_path) = config
338        .hook_sinks
339        .as_ref()
340        .and_then(|sinks| sinks.unix_socket_path.as_ref())
341        .filter(|path| !path.as_os_str().is_empty())
342    {
343        hooks.add_sink(Arc::new(UnixSocketHookSink::new(socket_path.clone())));
344    }
345
346    let runtime = Runtime::new(
347        config.clone(),
348        registry.clone(),
349        state_store,
350        Arc::new(ToolRegistry::default()),
351        Arc::new(McpManager::default()),
352        exec_policy,
353        hooks,
354    );
355
356    Ok(AppState {
357        config_path,
358        config: Arc::new(RwLock::new(config)),
359        runtime: Arc::new(Mutex::new(runtime)),
360        registry,
361        auth_token,
362    })
363}
364
365fn resolve_auth_token(options: &AppServerOptions) -> Result<Option<String>> {
366    let configured = options.auth_token.as_ref().map(|token| token.trim());
367    if let Some(token) = configured
368        && token.is_empty()
369    {
370        bail!("app-server auth token cannot be empty");
371    }
372
373    if options.insecure_no_auth {
374        if !options.listen.ip().is_loopback() {
375            bail!("refusing unauthenticated app-server bind on non-loopback address");
376        }
377        eprintln!("warning: app-server HTTP auth disabled by --insecure-no-auth");
378        return Ok(None);
379    }
380
381    let token = configured
382        .map(str::to_string)
383        .unwrap_or_else(|| format!("cwapp_{}", Uuid::new_v4().simple()));
384    if options.auth_token.is_some() {
385        eprintln!("app-server auth: bearer token required for HTTP routes.");
386    } else {
387        eprintln!("app-server auth: generated bearer token for this process.");
388        eprintln!("  Authorization: Bearer {token}");
389        eprintln!("  Pass --auth-token or set CODEWHALE_APP_SERVER_TOKEN for a stable token.");
390    }
391    Ok(Some(token))
392}
393
394fn cors_layer(extra_origins: &[String]) -> CorsLayer {
395    let mut origins: Vec<HeaderValue> = DEFAULT_CORS_ORIGINS
396        .iter()
397        .filter_map(|origin| HeaderValue::from_str(origin).ok())
398        .collect();
399    for raw in extra_origins {
400        let trimmed = raw.trim();
401        if trimmed.is_empty() {
402            continue;
403        }
404        match HeaderValue::from_str(trimmed) {
405            Ok(value) if !origins.contains(&value) => origins.push(value),
406            Ok(_) => {}
407            Err(err) => {
408                eprintln!("warning: ignoring invalid app-server CORS origin `{trimmed}`: {err}")
409            }
410        }
411    }
412
413    CorsLayer::new()
414        .allow_origin(origins)
415        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
416        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
417}
418
419async fn require_app_server_token(
420    State(state): State<AppState>,
421    req: Request,
422    next: Next,
423) -> Response {
424    let Some(expected) = state.auth_token.as_deref() else {
425        return next.run(req).await;
426    };
427    let authorized = req
428        .headers()
429        .get(header::AUTHORIZATION)
430        .and_then(|value| value.to_str().ok())
431        .and_then(|raw| raw.strip_prefix("Bearer "))
432        .is_some_and(|token| token == expected);
433
434    if authorized {
435        next.run(req).await
436    } else {
437        (
438            StatusCode::UNAUTHORIZED,
439            Json(json!({
440                "error": {
441                    "message": "app-server bearer token required",
442                    "status": StatusCode::UNAUTHORIZED.as_u16(),
443                }
444            })),
445        )
446            .into_response()
447    }
448}
449
450fn params_or_object(params: Value) -> Value {
451    if params.is_null() { json!({}) } else { params }
452}
453
454fn parse_params<T: DeserializeOwned>(params: Value) -> std::result::Result<T, JsonRpcError> {
455    serde_json::from_value(params).map_err(|err| JsonRpcError::invalid_params(err.to_string()))
456}
457
458fn jsonrpc_result(id: Option<Value>, result: Value) -> Value {
459    json!({
460        "jsonrpc": "2.0",
461        "id": id.unwrap_or(Value::Null),
462        "result": result
463    })
464}
465
466fn jsonrpc_error(id: Option<Value>, err: JsonRpcError) -> Value {
467    json!({
468        "jsonrpc": "2.0",
469        "id": id.unwrap_or(Value::Null),
470        "error": {
471            "code": err.code,
472            "message": err.message,
473            "data": err.data
474        }
475    })
476}
477
478impl JsonRpcError {
479    fn parse_error(message: impl Into<String>) -> Self {
480        Self {
481            code: -32700,
482            message: message.into(),
483            data: None,
484        }
485    }
486
487    fn invalid_request(message: impl Into<String>) -> Self {
488        Self {
489            code: -32600,
490            message: message.into(),
491            data: None,
492        }
493    }
494
495    fn method_not_found(method: &str) -> Self {
496        Self {
497            code: -32601,
498            message: format!("unsupported method: {method}"),
499            data: None,
500        }
501    }
502
503    fn invalid_params(message: impl Into<String>) -> Self {
504        Self {
505            code: -32602,
506            message: message.into(),
507            data: None,
508        }
509    }
510
511    fn internal(message: impl Into<String>) -> Self {
512        Self {
513            code: -32603,
514            message: message.into(),
515            data: None,
516        }
517    }
518}
519
520async fn handle_thread_request(
521    state: &AppState,
522    req: ThreadRequest,
523) -> std::result::Result<ThreadResponse, JsonRpcError> {
524    let mut runtime = state.runtime.lock().await;
525    runtime
526        .handle_thread(req)
527        .await
528        .map_err(|err| JsonRpcError::internal(err.to_string()))
529}
530
531async fn handle_prompt_request(
532    state: &AppState,
533    req: PromptRequest,
534) -> std::result::Result<PromptResponse, JsonRpcError> {
535    let mut runtime = state.runtime.lock().await;
536    runtime
537        .handle_prompt(req, &CliRuntimeOverrides::default())
538        .await
539        .map_err(|err| JsonRpcError::internal(err.to_string()))
540}
541
542async fn dispatch_stdio_request(
543    state: &AppState,
544    method: &str,
545    params: Value,
546) -> std::result::Result<StdioDispatchResult, JsonRpcError> {
547    let outcome = match method {
548        "healthz" | "app/healthz" => StdioDispatchResult {
549            result: json!({
550                "status": "ok",
551                "service": "deepseek-app-server",
552                "transport": "stdio"
553            }),
554            should_exit: false,
555        },
556        "capabilities" => StdioDispatchResult {
557            result: json!({
558                "transport": "stdio",
559                "families": ["thread/*", "app/*", "prompt/*"],
560                "methods": [
561                    "healthz",
562                    "thread/capabilities",
563                    "thread/request",
564                    "thread/create",
565                    "thread/start",
566                    "thread/resume",
567                    "thread/fork",
568                    "thread/list",
569                    "thread/read",
570                    "thread/set_name",
571                    "thread/archive",
572                    "thread/unarchive",
573                    "thread/message",
574                    "app/capabilities",
575                    "app/request",
576                    "app/config/get",
577                    "app/config/set",
578                    "app/config/unset",
579                    "app/config/list",
580                    "app/models",
581                    "app/thread_loaded_list",
582                    "prompt/capabilities",
583                    "prompt/request",
584                    "prompt/run",
585                    "shutdown"
586                ]
587            }),
588            should_exit: false,
589        },
590        "thread/capabilities" => StdioDispatchResult {
591            result: json!({
592                "methods": [
593                    "thread/request",
594                    "thread/create",
595                    "thread/start",
596                    "thread/resume",
597                    "thread/fork",
598                    "thread/list",
599                    "thread/read",
600                    "thread/set_name",
601                    "thread/archive",
602                    "thread/unarchive",
603                    "thread/message"
604                ]
605            }),
606            should_exit: false,
607        },
608        "thread/request" => {
609            let request: ThreadRequest = parse_params(params)?;
610            let response = handle_thread_request(state, request).await?;
611            StdioDispatchResult {
612                result: serde_json::to_value(response)
613                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
614                should_exit: false,
615            }
616        }
617        "thread/create" => {
618            #[derive(Debug, Deserialize)]
619            struct CreateParams {
620                #[serde(default)]
621                metadata: Value,
622            }
623            let parsed: CreateParams = parse_params(params_or_object(params))?;
624            let response = handle_thread_request(
625                state,
626                ThreadRequest::Create {
627                    metadata: parsed.metadata,
628                },
629            )
630            .await?;
631            StdioDispatchResult {
632                result: serde_json::to_value(response)
633                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
634                should_exit: false,
635            }
636        }
637        "thread/start" => {
638            let request = ThreadRequest::Start(parse_params(params_or_object(params))?);
639            let response = handle_thread_request(state, request).await?;
640            StdioDispatchResult {
641                result: serde_json::to_value(response)
642                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
643                should_exit: false,
644            }
645        }
646        "thread/resume" => {
647            let request = ThreadRequest::Resume(parse_params(params_or_object(params))?);
648            let response = handle_thread_request(state, request).await?;
649            StdioDispatchResult {
650                result: serde_json::to_value(response)
651                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
652                should_exit: false,
653            }
654        }
655        "thread/fork" => {
656            let request = ThreadRequest::Fork(parse_params(params_or_object(params))?);
657            let response = handle_thread_request(state, request).await?;
658            StdioDispatchResult {
659                result: serde_json::to_value(response)
660                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
661                should_exit: false,
662            }
663        }
664        "thread/list" => {
665            let request = ThreadRequest::List(parse_params(params_or_object(params))?);
666            let response = handle_thread_request(state, request).await?;
667            StdioDispatchResult {
668                result: serde_json::to_value(response)
669                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
670                should_exit: false,
671            }
672        }
673        "thread/read" => {
674            let request = ThreadRequest::Read(parse_params(params_or_object(params))?);
675            let response = handle_thread_request(state, request).await?;
676            StdioDispatchResult {
677                result: serde_json::to_value(response)
678                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
679                should_exit: false,
680            }
681        }
682        "thread/set_name" | "thread/set-name" => {
683            let request = ThreadRequest::SetName(parse_params(params_or_object(params))?);
684            let response = handle_thread_request(state, request).await?;
685            StdioDispatchResult {
686                result: serde_json::to_value(response)
687                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
688                should_exit: false,
689            }
690        }
691        "thread/archive" => {
692            let parsed: ThreadIdParams = parse_params(params_or_object(params))?;
693            let response = handle_thread_request(
694                state,
695                ThreadRequest::Archive {
696                    thread_id: parsed.thread_id,
697                },
698            )
699            .await?;
700            StdioDispatchResult {
701                result: serde_json::to_value(response)
702                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
703                should_exit: false,
704            }
705        }
706        "thread/unarchive" => {
707            let parsed: ThreadIdParams = parse_params(params_or_object(params))?;
708            let response = handle_thread_request(
709                state,
710                ThreadRequest::Unarchive {
711                    thread_id: parsed.thread_id,
712                },
713            )
714            .await?;
715            StdioDispatchResult {
716                result: serde_json::to_value(response)
717                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
718                should_exit: false,
719            }
720        }
721        "thread/message" => {
722            let parsed: ThreadMessageParams = parse_params(params_or_object(params))?;
723            let response = handle_thread_request(
724                state,
725                ThreadRequest::Message {
726                    thread_id: parsed.thread_id,
727                    input: parsed.input,
728                },
729            )
730            .await?;
731            StdioDispatchResult {
732                result: serde_json::to_value(response)
733                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
734                should_exit: false,
735            }
736        }
737        "app/capabilities" => {
738            let response =
739                process_app_request(state, AppRequest::Capabilities, AppTransport::Stdio).await;
740            StdioDispatchResult {
741                result: serde_json::to_value(response)
742                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
743                should_exit: false,
744            }
745        }
746        "app/request" => {
747            let request: AppRequest = parse_params(params)?;
748            let response = process_app_request(state, request, AppTransport::Stdio).await;
749            StdioDispatchResult {
750                result: serde_json::to_value(response)
751                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
752                should_exit: false,
753            }
754        }
755        "app/config/get" => {
756            let parsed: ConfigGetParams = parse_params(params_or_object(params))?;
757            let response = process_app_request(
758                state,
759                AppRequest::ConfigGet { key: parsed.key },
760                AppTransport::Stdio,
761            )
762            .await;
763            StdioDispatchResult {
764                result: serde_json::to_value(response)
765                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
766                should_exit: false,
767            }
768        }
769        "app/config/set" => {
770            let parsed: ConfigSetParams = parse_params(params_or_object(params))?;
771            let response = process_app_request(
772                state,
773                AppRequest::ConfigSet {
774                    key: parsed.key,
775                    value: parsed.value,
776                },
777                AppTransport::Stdio,
778            )
779            .await;
780            StdioDispatchResult {
781                result: serde_json::to_value(response)
782                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
783                should_exit: false,
784            }
785        }
786        "app/config/unset" => {
787            let parsed: ConfigGetParams = parse_params(params_or_object(params))?;
788            let response = process_app_request(
789                state,
790                AppRequest::ConfigUnset { key: parsed.key },
791                AppTransport::Stdio,
792            )
793            .await;
794            StdioDispatchResult {
795                result: serde_json::to_value(response)
796                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
797                should_exit: false,
798            }
799        }
800        "app/config/list" => {
801            let response =
802                process_app_request(state, AppRequest::ConfigList, AppTransport::Stdio).await;
803            StdioDispatchResult {
804                result: serde_json::to_value(response)
805                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
806                should_exit: false,
807            }
808        }
809        "app/models" => {
810            let response =
811                process_app_request(state, AppRequest::Models, AppTransport::Stdio).await;
812            StdioDispatchResult {
813                result: serde_json::to_value(response)
814                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
815                should_exit: false,
816            }
817        }
818        "app/thread_loaded_list" | "app/thread-loaded-list" => {
819            let response =
820                process_app_request(state, AppRequest::ThreadLoadedList, AppTransport::Stdio).await;
821            StdioDispatchResult {
822                result: serde_json::to_value(response)
823                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
824                should_exit: false,
825            }
826        }
827        "prompt/capabilities" => StdioDispatchResult {
828            result: json!({
829                "methods": ["prompt/request", "prompt/run"]
830            }),
831            should_exit: false,
832        },
833        "prompt/request" | "prompt/run" => {
834            let request: PromptRequest = parse_params(params)?;
835            let response = handle_prompt_request(state, request).await?;
836            StdioDispatchResult {
837                result: serde_json::to_value(response)
838                    .map_err(|err| JsonRpcError::internal(err.to_string()))?,
839                should_exit: false,
840            }
841        }
842        "shutdown" => StdioDispatchResult {
843            result: json!({"ok": true, "status": "stopped"}),
844            should_exit: true,
845        },
846        _ => return Err(JsonRpcError::method_not_found(method)),
847    };
848    Ok(outcome)
849}
850
851async fn process_app_request(
852    state: &AppState,
853    req: AppRequest,
854    transport: AppTransport,
855) -> AppResponse {
856    match req {
857        AppRequest::Capabilities => AppResponse {
858            ok: true,
859            data: json!({
860                "routes": ["/thread", "/app", "/prompt", "/tool", "/jobs", "/mcp/startup"],
861                "config": ["get", "set", "unset", "list"],
862                "events": ["response_start", "response_delta", "response_end", "tool_call_start", "tool_call_result", "mcp_startup_update", "mcp_startup_complete"],
863                "transport": "stdio+http",
864                "config_path": state.config_path.as_ref().map(|p| p.display().to_string()),
865            }),
866            events: Vec::new(),
867        },
868        AppRequest::ConfigGet { key } => {
869            let cfg = state.config.read().await;
870            let value = match transport {
871                AppTransport::Http => cfg.get_display_value(&key),
872                AppTransport::Stdio => cfg.get_value(&key),
873            };
874            AppResponse {
875                ok: true,
876                data: json!({ "key": key, "value": value }),
877                events: Vec::new(),
878            }
879        }
880        AppRequest::ConfigSet { key, value } => {
881            let mut cfg = state.config.write().await;
882            let result = cfg.set_value(&key, &value);
883            let ok = result.is_ok();
884            let message = result.err().map(|e| e.to_string());
885            let snapshot = cfg.clone();
886            drop(cfg);
887            if let Err(e) = persist_config(state, snapshot).await {
888                tracing::error!("Failed to persist config after set: {e}");
889            }
890            AppResponse {
891                ok,
892                data: json!({ "key": key, "value": value, "error": message }),
893                events: Vec::new(),
894            }
895        }
896        AppRequest::ConfigUnset { key } => {
897            let mut cfg = state.config.write().await;
898            let result = cfg.unset_value(&key);
899            let ok = result.is_ok();
900            let message = result.err().map(|e| e.to_string());
901            let snapshot = cfg.clone();
902            drop(cfg);
903            if let Err(e) = persist_config(state, snapshot).await {
904                tracing::error!("Failed to persist config after unset: {e}");
905            }
906            AppResponse {
907                ok,
908                data: json!({ "key": key, "error": message }),
909                events: Vec::new(),
910            }
911        }
912        AppRequest::ConfigList => {
913            let cfg = state.config.read().await;
914            AppResponse {
915                ok: true,
916                data: json!({ "values": cfg.list_values() }),
917                events: Vec::new(),
918            }
919        }
920        AppRequest::Models => AppResponse {
921            ok: true,
922            data: json!({ "models": state.registry.list() }),
923            events: Vec::new(),
924        },
925        AppRequest::ThreadLoadedList => {
926            let mut runtime = state.runtime.lock().await;
927            let response = runtime
928                .handle_thread(codewhale_protocol::ThreadRequest::List(
929                    codewhale_protocol::ThreadListParams {
930                        include_archived: false,
931                        limit: Some(50),
932                    },
933                ))
934                .await;
935            match response {
936                Ok(thread_resp) => AppResponse {
937                    ok: true,
938                    data: json!({ "threads": thread_resp.threads }),
939                    events: thread_resp.events,
940                },
941                Err(err) => AppResponse {
942                    ok: false,
943                    data: json!({ "error": err.to_string() }),
944                    events: Vec::new(),
945                },
946            }
947        }
948    }
949}
950
951async fn persist_config(state: &AppState, config: codewhale_config::ConfigToml) -> Result<()> {
952    if state.config_path.is_none() {
953        return Ok(());
954    }
955    let mut store = ConfigStore::load(state.config_path.clone())?;
956    store.config = config;
957    store.save()
958}
959
/// Unit and router-level tests for the app server: auth, CORS, config
/// redaction, permission loading, and JSON-RPC helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{Body, to_bytes};
    use codewhale_protocol::AppRequest;
    use std::fs;
    use tower::ServiceExt;

    /// Builds a router backed by a temp config file containing a secret API key.
    /// The `TempDir` is returned so the file lives as long as the test.
    fn app_with_config(auth_token: Option<&str>) -> (Router, tempfile::TempDir) {
        let tmp = tempfile::tempdir().expect("tempdir");
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, "api_key = \"sk-deepseek-secret\"\n").expect("write config");
        let state = build_state(
            Some(config_path),
            auth_token.map(std::string::ToString::to_string),
        )
        .expect("state");
        (app_router(state, &[]), tmp)
    }

    async fn response_body_json(response: Response) -> Value {
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body bytes");
        serde_json::from_slice(&bytes).expect("json response")
    }

    /// Sends a simple GET carrying `origin` through a router wrapped in
    /// `cors_layer(extras)`. Returns the `Access-Control-Allow-Origin` header
    /// the layer attached, if any.
    async fn cors_allow_origin(extras: &[String], origin: &'static str) -> Option<HeaderValue> {
        let app: Router = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(cors_layer(extras));
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/")
                    .header(header::ORIGIN, origin)
                    .body(Body::empty())
                    .expect("request"),
            )
            .await
            .expect("response");
        response
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .cloned()
    }

    #[tokio::test]
    async fn http_app_routes_require_bearer_token_when_auth_enabled() {
        let (app, _tmp) = app_with_config(Some("test-token"));
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/app")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(
                        serde_json::to_vec(&AppRequest::ConfigGet {
                            key: "api_key".to_string(),
                        })
                        .expect("request json"),
                    ))
                    .expect("request"),
            )
            .await
            .expect("response");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn http_config_get_redacts_sensitive_values_after_auth() {
        let (app, _tmp) = app_with_config(Some("test-token"));
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/app")
                    .header(header::AUTHORIZATION, "Bearer test-token")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(
                        serde_json::to_vec(&AppRequest::ConfigGet {
                            key: "api_key".to_string(),
                        })
                        .expect("request json"),
                    ))
                    .expect("request"),
            )
            .await
            .expect("response");

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_body_json(response).await;
        assert_eq!(body["data"]["value"], "sk-d***cret");
    }

    #[tokio::test]
    async fn cors_does_not_allow_arbitrary_origins() {
        let (app, _tmp) = app_with_config(Some("test-token"));
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .uri("/healthz")
                    .header(header::ORIGIN, "https://attacker.example")
                    .body(Body::empty())
                    .expect("request"),
            )
            .await
            .expect("response");

        assert_eq!(response.status(), StatusCode::OK);
        assert!(
            response
                .headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .is_none()
        );
    }

    #[tokio::test]
    async fn build_state_loads_permissions_into_runtime_policy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, "api_key = \"sk-deepseek-secret\"\n").expect("write config");
        fs::write(
            tmp.path().join("permissions.toml"),
            r#"
            [[rules]]
            tool = "exec_shell"
            command = "cargo test"
            "#,
        )
        .expect("write permissions");

        let state = build_state(Some(config_path), None).expect("state");
        let runtime = state.runtime.lock().await;
        let decision = runtime
            .exec_policy
            .check(codewhale_execpolicy::ExecPolicyContext {
                command: "cargo test --workspace",
                cwd: "/workspace",
                tool: Some("exec_shell"),
                path: None,
                ask_for_approval: codewhale_execpolicy::AskForApproval::UnlessTrusted,
                sandbox_mode: Some("workspace-write"),
            })
            .expect("policy check");

        assert!(decision.allow);
        assert!(decision.requires_approval);
        assert_eq!(
            decision.matched_rule.as_deref(),
            Some("tool=exec_shell command=cargo test")
        );
    }

    #[test]
    fn non_loopback_bind_without_auth_fails_fast() {
        let options = AppServerOptions {
            listen: "0.0.0.0:8787".parse().expect("socket addr"),
            config_path: None,
            auth_token: None,
            insecure_no_auth: true,
            cors_origins: Vec::new(),
        };

        let err = resolve_auth_token(&options).expect_err("non-loopback unauth should fail");
        assert!(
            err.to_string()
                .contains("refusing unauthenticated app-server bind")
        );
    }

    #[tokio::test]
    async fn stdio_transport_keeps_raw_config_get_for_legacy_clients() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let config_path = tmp.path().join("config.toml");
        fs::write(&config_path, "").expect("write config");
        let state = build_state(Some(config_path), None).expect("state");
        {
            let mut cfg = state.config.write().await;
            cfg.api_key = Some("sk-deepseek-secret".to_string());
        }

        let response = process_app_request(
            &state,
            AppRequest::ConfigGet {
                key: "api_key".to_string(),
            },
            AppTransport::Stdio,
        )
        .await;

        assert_eq!(response.data["value"], "sk-deepseek-secret");
    }

    // ── resolve_auth_token ─────────────────────────────────────────────

    #[test]
    fn auth_token_empty_string_fails() {
        let options = AppServerOptions {
            listen: "127.0.0.1:0".parse().expect("addr"),
            config_path: None,
            auth_token: Some("  ".to_string()),
            insecure_no_auth: false,
            cors_origins: Vec::new(),
        };
        let err = resolve_auth_token(&options).expect_err("empty token should fail");
        assert!(err.to_string().contains("cannot be empty"));
    }

    #[test]
    fn auth_token_generated_when_none_provided() {
        let options = AppServerOptions {
            listen: "127.0.0.1:0".parse().expect("addr"),
            config_path: None,
            auth_token: None,
            insecure_no_auth: false,
            cors_origins: Vec::new(),
        };
        let token = resolve_auth_token(&options).unwrap();
        assert!(token.is_some());
        assert!(token.unwrap().starts_with("cwapp_"));
    }

    #[test]
    fn auth_token_explicit_is_preserved() {
        let options = AppServerOptions {
            listen: "127.0.0.1:0".parse().expect("addr"),
            config_path: None,
            auth_token: Some("my-secret".to_string()),
            insecure_no_auth: false,
            cors_origins: Vec::new(),
        };
        let token = resolve_auth_token(&options).unwrap();
        assert_eq!(token.as_deref(), Some("my-secret"));
    }

    #[test]
    fn insecure_no_auth_on_loopback_returns_none() {
        let options = AppServerOptions {
            listen: "127.0.0.1:0".parse().expect("addr"),
            config_path: None,
            auth_token: None,
            insecure_no_auth: true,
            cors_origins: Vec::new(),
        };
        let token = resolve_auth_token(&options).unwrap();
        assert!(token.is_none());
    }

    // ── cors_layer ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn cors_layer_includes_default_origins() {
        let allowed = cors_allow_origin(&[], "http://localhost:5173").await;
        assert_eq!(
            allowed,
            Some(HeaderValue::from_static("http://localhost:5173"))
        );
    }

    #[tokio::test]
    async fn cors_layer_adds_extra_origins() {
        let extras = vec!["https://example.com".to_string()];
        // Not allowed by default...
        assert!(
            cors_allow_origin(&[], "https://example.com")
                .await
                .is_none()
        );
        // ...but allowed once configured as an extra origin.
        let allowed = cors_allow_origin(&extras, "https://example.com").await;
        assert_eq!(allowed, Some(HeaderValue::from_static("https://example.com")));
    }

    #[tokio::test]
    async fn cors_layer_skips_empty_origins() {
        let extras = vec!["".to_string(), "  ".to_string()];
        // Blank entries must not break the layer; the defaults stay in effect.
        let allowed = cors_allow_origin(&extras, "tauri://localhost").await;
        assert_eq!(allowed, Some(HeaderValue::from_static("tauri://localhost")));
    }

    // ── JsonRpc helpers ────────────────────────────────────────────────

    #[test]
    fn params_or_object_returns_object_for_null() {
        let result = params_or_object(Value::Null);
        assert_eq!(result, json!({}));
    }

    #[test]
    fn params_or_object_passthrough_for_non_null() {
        let input = json!({"key": "value"});
        let result = params_or_object(input.clone());
        assert_eq!(result, input);
    }

    #[test]
    fn jsonrpc_result_format() {
        let result = jsonrpc_result(Some(json!(1)), json!({"ok": true}));
        assert_eq!(result["jsonrpc"], "2.0");
        assert_eq!(result["id"], 1);
        assert_eq!(result["result"]["ok"], true);
    }

    #[test]
    fn jsonrpc_result_null_id() {
        let result = jsonrpc_result(None, json!(null));
        assert_eq!(result["id"], Value::Null);
    }

    #[test]
    fn jsonrpc_error_format() {
        let err = jsonrpc_error(Some(json!(2)), JsonRpcError::internal("oops"));
        assert_eq!(err["jsonrpc"], "2.0");
        assert_eq!(err["id"], 2);
        assert_eq!(err["error"]["code"], -32603);
        assert_eq!(err["error"]["message"], "oops");
    }

    #[test]
    fn jsonrpc_error_codes() {
        assert_eq!(JsonRpcError::parse_error("").code, -32700);
        assert_eq!(JsonRpcError::invalid_request("").code, -32600);
        assert_eq!(JsonRpcError::method_not_found("x").code, -32601);
        assert_eq!(JsonRpcError::invalid_params("").code, -32602);
        assert_eq!(JsonRpcError::internal("").code, -32603);
    }

    // ── AppServerOptions ───────────────────────────────────────────────

    #[test]
    fn app_server_options_debug_does_not_leak_token() {
        let options = AppServerOptions {
            listen: "127.0.0.1:8080".parse().expect("addr"),
            config_path: None,
            auth_token: Some("secret-token".to_string()),
            insecure_no_auth: false,
            cors_origins: vec!["https://example.com".to_string()],
        };
        let debug = format!("{options:?}");
        assert!(!debug.contains("secret-token"));
        assert!(debug.contains("<redacted>"));
        assert!(debug.contains("8080"));
    }

    // ── Default CORS origins ──────────────────────────────────────────

    #[test]
    fn default_cors_origins_include_common_dev_ports() {
        assert!(DEFAULT_CORS_ORIGINS.contains(&"http://localhost:3000"));
        assert!(DEFAULT_CORS_ORIGINS.contains(&"http://localhost:5173"));
        assert!(DEFAULT_CORS_ORIGINS.contains(&"tauri://localhost"));
    }
}