Skip to main content

split_brain_harness/
serve.rs

//! OpenAI-compatible HTTP proxy server.
//!
//! Exposes `POST /v1/chat/completions` so any OpenAI-speaking client
//! (LangChain, Continue.dev, Cursor, custom agents) can route through the
//! soul-injected telemetry pipeline with zero code changes.
//!
//! Telemetry is returned two ways:
//!   1. The `content` field carries both the model's answer AND a
//!      `<!-- sbh-telemetry: {...} -->` HTML comment at the end.
//!   2. The `x-sbh-telemetry` response header carries the same JSON, URL-encoded.
//!
//! Hardening:
//!   - `SBH_SERVE_KEY`      — require a Bearer token on `POST /v1/chat/completions`
//!                            and `GET /metrics` (`GET /health` is always open)
//!   - `SBH_SERVE_RATE`     — max requests/min per IP (default 60)
//!   - `SBH_SERVE_MAX_BODY` — max body bytes (default 1 MiB)
//!
//! Multi-turn session tracking:
//!   Pass `x-sbh-session: <id>` (1–64 characters from `[A-Za-z0-9_-]`) on
//!   requests to link turns into a session; a missing or malformed ID is
//!   replaced by a freshly minted one. The response echoes the session ID.
//!   If the manipulation_risk signal shows an upward trend across turns
//!   (slow-boil escalation), the response sets
//!   `x-sbh-session-alert: escalation_detected`. Sessions expire after 30
//!   minutes of inactivity: an expired session is dropped when it is next
//!   accessed, and a periodic background sweep evicts the rest.
//!
//! Start with: `sbh serve [--listen <addr>] [--tls-cert <pem> --tls-key <pem>]`
//! (default listen address: 127.0.0.1:8088)
25use std::collections::{HashMap, VecDeque};
26use std::net::{IpAddr, SocketAddr};
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::{Arc, Mutex};
29use std::time::{Duration, Instant};
30
31use axum::{
32    extract::{ConnectInfo, DefaultBodyLimit, State},
33    http::{HeaderMap, HeaderValue, StatusCode},
34    response::IntoResponse,
35    routing::{get, post},
36    Json, Router,
37};
38use serde::{Deserialize, Serialize};
39
40use anyhow::Context as _;
41use crate::{analyze, session_log, types::Config};
42
43// ---------------------------------------------------------------------------
44// Request / response types (OpenAI wire format subset)
45// ---------------------------------------------------------------------------
46
47#[derive(Debug, Deserialize)]
48pub struct ChatRequest {
49    pub model: Option<String>,
50    pub messages: Vec<ChatMessage>,
51    #[serde(default)]
52    pub stream: bool,
53    // All other fields are accepted and ignored
54    #[serde(flatten)]
55    pub _extra: serde_json::Value,
56}
57
58#[derive(Debug, Deserialize, Serialize, Clone)]
59pub struct ChatMessage {
60    pub role: String,
61    pub content: String,
62}
63
64#[derive(Debug, Serialize)]
65pub struct ChatResponse {
66    pub id: String,
67    pub object: String,
68    pub created: u64,
69    pub model: String,
70    pub choices: Vec<ChatChoice>,
71    pub usage: Usage,
72}
73
74#[derive(Debug, Serialize)]
75pub struct ChatChoice {
76    pub index: u32,
77    pub message: ChatMessage,
78    pub finish_reason: String,
79}
80
81#[derive(Debug, Serialize)]
82pub struct Usage {
83    pub prompt_tokens: u32,
84    pub completion_tokens: u32,
85    pub total_tokens: u32,
86}
87
88#[derive(Debug, Serialize)]
89struct ErrorBody {
90    error: ErrorDetail,
91}
92
93#[derive(Debug, Serialize)]
94struct ErrorDetail {
95    message: String,
96    #[serde(rename = "type")]
97    kind: String,
98}
99
100// ---------------------------------------------------------------------------
101// Session tracking — multi-turn manipulation detection
102// ---------------------------------------------------------------------------
103
/// How many of a session's most recent turns are retained for trend analysis.
const SESSION_MAX_TURNS: usize = 10;
/// Inactivity period (30 minutes) after which a session's history is discarded.
const SESSION_TTL: Duration = Duration::from_secs(1_800);
/// Ceiling on the number of sessions held in memory at once. A request that
/// would open a new session beyond this is refused, keeping the map bounded.
const SESSION_MAX_COUNT: usize = 10_000;
/// Period (5 minutes) of the background task that evicts expired sessions, so
/// the request path only ever checks the one session it touches.
const SESSION_SWEEP_INTERVAL: Duration = Duration::from_secs(300);
112
113// ---------------------------------------------------------------------------
114// Rate limiter — 16-shard sliding window, no extra deps
115// ---------------------------------------------------------------------------
116
/// Number of independently locked shards. Requests from different IPs usually
/// hash to different shards, so they rarely wait on the same mutex.
const RATE_LIMITER_SHARDS: usize = 16;
/// Hard cap on total tracked IPs across all shards. Beyond this, new IPs
/// are passed through untracked rather than allocating unbounded memory.
const MAX_TRACKED_IPS: usize = 50_000;
/// Each shard's share of `MAX_TRACKED_IPS`.
const MAX_IPS_PER_SHARD: usize = MAX_TRACKED_IPS / RATE_LIMITER_SHARDS;

/// Per-IP sliding-window rate limiter (60-second window).
///
/// Each tracked IP maps to a queue of the timestamps of its accepted requests;
/// IPs are spread across `RATE_LIMITER_SHARDS` mutex-protected maps by hash.
struct ShardedRateLimiter {
    shards: Box<[Mutex<HashMap<IpAddr, VecDeque<Instant>>>; RATE_LIMITER_SHARDS]>,
}

impl ShardedRateLimiter {
    fn new() -> Self {
        Self {
            shards: Box::new(std::array::from_fn(|_| Mutex::new(HashMap::new()))),
        }
    }

    /// Index of the shard responsible for `ip`.
    fn shard_idx(ip: IpAddr) -> usize {
        use std::hash::{Hash, Hasher};
        let mut h = std::collections::hash_map::DefaultHasher::new();
        ip.hash(&mut h);
        (h.finish() as usize) % RATE_LIMITER_SHARDS
    }

    /// Records a request from `ip` and returns `true` when it is allowed, i.e.
    /// fewer than `max_per_minute` requests from that IP were accepted in the
    /// trailing 60 seconds. Rejected requests (`false`) are not recorded.
    ///
    /// When `ip` is new and its shard is full, every entry whose newest
    /// timestamp has left the window is evicted in a single pass. Such entries
    /// hold no in-window state, so evicting them never lets a client through
    /// early. If nothing is evictable, the request is allowed without being
    /// tracked (fail open instead of growing without bound).
    fn check(&self, ip: IpAddr, max_per_minute: u32) -> bool {
        let idx = Self::shard_idx(ip);
        let now = Instant::now();
        let window = Duration::from_secs(60);
        // A poisoned lock only means another request panicked mid-update; the
        // map is still structurally valid, so keep serving.
        let mut shard = self.shards[idx].lock().unwrap_or_else(|e| e.into_inner());
        if !shard.contains_key(&ip) && shard.len() >= MAX_IPS_PER_SHARD {
            // Bulk eviction: evicting one stale entry per new IP made a flood
            // of fresh addresses cost a full shard scan on every request.
            shard.retain(|_, q| matches!(q.back(), Some(&t) if now.duration_since(t) <= window));
            if shard.len() >= MAX_IPS_PER_SHARD {
                return true;
            }
        }
        let queue = shard.entry(ip).or_default();
        while let Some(&front) = queue.front() {
            if now.duration_since(front) > window {
                queue.pop_front();
            } else {
                break;
            }
        }
        if queue.len() >= max_per_minute as usize {
            return false;
        }
        queue.push_back(now);
        true
    }
}
175
176/// One analyzed turn in a session, recording the risk signals.
177#[derive(Debug, Clone)]
178struct SessionTurn {
179    manipulation_risk: String,
180}
181
182/// Ring buffer of the most recent turns for one session.
183#[derive(Debug)]
184struct SessionHistory {
185    turns: VecDeque<SessionTurn>,
186    last_seen: Instant,
187}
188
189impl SessionHistory {
190    fn new() -> Self {
191        Self {
192            turns: VecDeque::new(),
193            last_seen: Instant::now(),
194        }
195    }
196
197    fn push(&mut self, risk: &str) {
198        let now = Instant::now();
199        self.last_seen = now;
200        if self.turns.len() >= SESSION_MAX_TURNS {
201            self.turns.pop_front();
202        }
203        self.turns.push_back(SessionTurn {
204            manipulation_risk: risk.to_string(),
205        });
206    }
207
208    /// Returns true when the current session shows an upward escalation in
209    /// manipulation_risk compared to the historical mean. Requires ≥3 turns.
210    ///
211    /// Algorithm: map risk to 0/1/2, compute mean of all-but-last turns.
212    /// Escalation fires when the latest turn scores above the historical mean
213    /// by more than 0.5 AND is not "low".
214    fn is_escalating(&self) -> bool {
215        if self.turns.len() < 3 {
216            return false;
217        }
218        let scores: Vec<f64> = self
219            .turns
220            .iter()
221            .map(|t| risk_score(&t.manipulation_risk))
222            .collect();
223        let n = scores.len();
224        let historical_mean: f64 = scores[..n - 1].iter().sum::<f64>() / (n - 1) as f64;
225        let current = scores[n - 1];
226        current > (historical_mean + 0.5) && current >= 1.0
227    }
228
229    fn turn_count(&self) -> usize {
230        self.turns.len()
231    }
232
233    /// Returns (trajectory, historical_mean) — the same values used by
234    /// `is_escalating`, exposed so the caller can write a session log entry.
235    fn risk_summary(&self) -> (Vec<String>, f64) {
236        let trajectory: Vec<String> = self
237            .turns
238            .iter()
239            .map(|t| t.manipulation_risk.clone())
240            .collect();
241        let n = trajectory.len();
242        if n < 2 {
243            return (trajectory, 0.0);
244        }
245        let scores: Vec<f64> = self
246            .turns
247            .iter()
248            .map(|t| risk_score(&t.manipulation_risk))
249            .collect();
250        let historical_mean = scores[..n - 1].iter().sum::<f64>() / (n - 1) as f64;
251        (trajectory, historical_mean)
252    }
253}
254
/// Numeric weight of a `manipulation_risk` label: `"high"` → 2, `"medium"` → 1,
/// anything else (including `"low"` and unexpected spellings) → 0.
fn risk_score(risk: &str) -> f64 {
    if risk == "high" {
        2.0
    } else if risk == "medium" {
        1.0
    } else {
        0.0
    }
}
262
263// ---------------------------------------------------------------------------
264// Witness status cache — polled every 30 s by a background task
265// ---------------------------------------------------------------------------
266
// Status codes stored in the shared `AtomicU8` witness cache.
const WITNESS_ACTIVE: u8 = 0;
const WITNESS_INACTIVE: u8 = 1;
const WITNESS_UNCONFIGURED: u8 = 2;

/// Header text for a cached witness status code; any code other than
/// active/inactive reads as `"not-configured"`.
fn witness_status_str(v: u8) -> &'static str {
    if v == WITNESS_ACTIVE {
        "active"
    } else if v == WITNESS_INACTIVE {
        "inactive"
    } else {
        "not-configured"
    }
}
278
279/// Spawn a background task that polls `witness status` once at startup and
280/// every 30 seconds thereafter. The result is stored in `cache` (an AtomicU8)
281/// so that the hot request path never blocks on a subprocess.
282///
283/// Only spawned when `audit_path` is Some — otherwise status is fixed to
284/// WITNESS_UNCONFIGURED.
285fn spawn_witness_poller(cache: Arc<std::sync::atomic::AtomicU8>) {
286    tokio::spawn(async move {
287        loop {
288            let result = tokio::process::Command::new("witness")
289                .arg("status")
290                .stdout(std::process::Stdio::null())
291                .stderr(std::process::Stdio::null())
292                .status()
293                .await;
294            let val = match result {
295                Ok(s) if s.success() => WITNESS_ACTIVE,
296                _ => WITNESS_INACTIVE,
297            };
298            cache.store(val, std::sync::atomic::Ordering::Relaxed);
299            tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
300        }
301    });
302}
303
304// ---------------------------------------------------------------------------
305// Metrics — lock-free counters, Prometheus text exposition
306// ---------------------------------------------------------------------------
307
/// Process-wide request counters, exposed in Prometheus text format by
/// `GET /metrics`. All updates and reads use relaxed atomics; the values are
/// independent tallies and are not used for synchronization.
#[derive(Default)]
pub struct Metrics {
    /// Every `POST /v1/chat/completions` request received.
    pub requests_total: AtomicU64,
    /// Requests answered with 200 OK.
    pub requests_ok_total: AtomicU64,
    /// Requests answered with a 4xx or 5xx status.
    pub requests_error_total: AtomicU64,
    /// Requests rejected for a missing or wrong serve key.
    pub auth_failures_total: AtomicU64,
    /// Requests rejected by the per-IP rate limiter.
    pub rate_limit_total: AtomicU64,
    /// Slow-boil session escalations detected.
    pub escalations_total: AtomicU64,
}

impl Metrics {
    /// Adds one to `counter`.
    fn inc(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Renders every counter plus the two supplied gauges as Prometheus text
    /// exposition: a `# HELP` line, a `# TYPE` line and one sample per metric.
    pub fn render(&self, active_sessions: usize, uptime_secs: u64) -> String {
        let read = |c: &AtomicU64| c.load(Ordering::Relaxed);
        let families: [(&str, &str, &str, u64); 8] = [
            ("sbh_requests_total", "counter", "Total POST /v1/chat/completions requests", read(&self.requests_total)),
            ("sbh_requests_ok_total", "counter", "Requests that returned 200 OK", read(&self.requests_ok_total)),
            ("sbh_requests_error_total", "counter", "Requests that returned 4xx or 5xx", read(&self.requests_error_total)),
            ("sbh_auth_failures_total", "counter", "Requests rejected for missing/invalid auth key", read(&self.auth_failures_total)),
            ("sbh_rate_limit_total", "counter", "Requests rejected by per-IP rate limiter", read(&self.rate_limit_total)),
            ("sbh_escalations_total", "counter", "Slow-boil session escalation events detected", read(&self.escalations_total)),
            ("sbh_active_sessions", "gauge", "Sessions currently held in memory", active_sessions as u64),
            ("sbh_uptime_seconds", "gauge", "Seconds since sbh serve started", uptime_secs),
        ];
        families
            .iter()
            .map(|(name, kind, help, value)| {
                format!("# HELP {name} {help}\n# TYPE {name} {kind}\n{name} {value}\n")
            })
            .collect()
    }
}
343
344// ---------------------------------------------------------------------------
345// Server state
346// ---------------------------------------------------------------------------
347
348#[derive(Clone)]
349pub struct ServeState {
350    config: Arc<Config>,
351    /// Per-IP sliding window — sharded to avoid global lock contention.
352    rate_limiter: Arc<ShardedRateLimiter>,
353    /// Per-session turn history for multi-turn escalation detection.
354    sessions: Arc<Mutex<HashMap<String, SessionHistory>>>,
355    /// Path to append-only session escalation log. Written on every escalation event.
356    session_log_path: Option<String>,
357    /// Prometheus-style counters, shared across handler clones.
358    metrics: Arc<Metrics>,
359    /// Timestamp of server start, used to compute uptime.
360    start_time: Arc<Instant>,
361    /// Cached witness status, refreshed every 30s by a background task.
362    /// "active" | "inactive" | "not-configured"
363    witness_status: Arc<std::sync::atomic::AtomicU8>,
364}
365
366// ---------------------------------------------------------------------------
367// Route handler
368// ---------------------------------------------------------------------------
369
370async fn chat_completions(
371    State(state): State<ServeState>,
372    ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
373    headers: HeaderMap,
374    Json(req): Json<ChatRequest>,
375) -> impl IntoResponse {
376    let config = &*state.config;
377    Metrics::inc(&state.metrics.requests_total);
378
379    // --- serve-level auth (checked before anything else) ---
380    if let Some(sk) = &config.serve_key {
381        let provided = headers
382            .get("authorization")
383            .and_then(|v| v.to_str().ok())
384            .map(|s| s.trim_start_matches("Bearer ").trim().to_string())
385            .unwrap_or_default();
386        if &provided != sk {
387            Metrics::inc(&state.metrics.auth_failures_total);
388            Metrics::inc(&state.metrics.requests_error_total);
389            let body = ErrorBody {
390                error: ErrorDetail {
391                    message: "Unauthorized: invalid or missing SBH serve key.".into(),
392                    kind: "authentication_error".into(),
393                },
394            };
395            return (
396                StatusCode::UNAUTHORIZED,
397                HeaderMap::new(),
398                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
399            )
400                .into_response();
401        }
402    }
403
404    // --- per-IP rate limit ---
405    let ip = remote_addr.ip();
406    if !state.rate_limiter.check(ip, config.serve_rate_limit) {
407        Metrics::inc(&state.metrics.rate_limit_total);
408        Metrics::inc(&state.metrics.requests_error_total);
409        let body = ErrorBody {
410            error: ErrorDetail {
411                message: format!(
412                    "Rate limit exceeded: max {} requests/min per IP.",
413                    config.serve_rate_limit
414                ),
415                kind: "rate_limit_error".into(),
416            },
417        };
418        return (
419            StatusCode::TOO_MANY_REQUESTS,
420            HeaderMap::new(),
421            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
422        )
423            .into_response();
424    }
425
426    // --- streaming not supported ---
427    if req.stream {
428        let body = ErrorBody {
429            error: ErrorDetail {
430                message: "sbh serve does not support streaming. Set stream=false.".into(),
431                kind: "unsupported_parameter".into(),
432            },
433        };
434        return (
435            StatusCode::BAD_REQUEST,
436            HeaderMap::new(),
437            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
438        )
439            .into_response();
440    }
441
442    // --- extract last user message ---
443    let user_input = req
444        .messages
445        .iter()
446        .rev()
447        .find(|m| m.role == "user")
448        .map(|m| m.content.as_str())
449        .unwrap_or("");
450
451    if user_input.is_empty() {
452        let body = ErrorBody {
453            error: ErrorDetail {
454                message: "No user message found in messages array.".into(),
455                kind: "invalid_request_error".into(),
456            },
457        };
458        return (
459            StatusCode::BAD_REQUEST,
460            HeaderMap::new(),
461            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
462        )
463            .into_response();
464    }
465
466    // --- optionally forward Authorization as upstream API key
467    //     (only when serve_key is NOT set — when serve_key is set, auth is
468    //      used for access control and must not leak to the upstream) ---
469    let mut cfg = (*state.config).clone();
470    if config.serve_key.is_none() {
471        if let Some(auth) = headers.get("authorization") {
472            if let Ok(val) = auth.to_str() {
473                let key = val.trim_start_matches("Bearer ").trim().to_string();
474                if !key.is_empty() {
475                    cfg.api_key = Some(key);
476                }
477            }
478        }
479    }
480
481    // --- session ID: validate client-supplied or mint a cryptographically random one ---
482    let session_id = headers
483        .get("x-sbh-session")
484        .and_then(|v| v.to_str().ok())
485        // Only accept IDs that are safe for HTTP headers and won't enable enumeration
486        .filter(|s| {
487            !s.is_empty()
488                && s.len() <= 64
489                && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
490        })
491        .map(|s| s.to_string())
492        .unwrap_or_else(mint_session_id);
493
494    // --- run the full harness pipeline ---
495    let result = match analyze(user_input, &cfg).await {
496        Ok(r) => r,
497        Err(e) => {
498            Metrics::inc(&state.metrics.requests_error_total);
499            let msg = e.to_string();
500            let (status, kind) = if msg.contains("input")
501                || msg.contains("null byte")
502                || msg.contains("too long")
503                || msg.contains("control char")
504            {
505                (StatusCode::BAD_REQUEST, "invalid_request_error")
506            } else {
507                (StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
508            };
509            let body = ErrorBody {
510                error: ErrorDetail {
511                    message: msg,
512                    kind: kind.into(),
513                },
514            };
515            return (
516                status,
517                HeaderMap::new(),
518                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
519            )
520                .into_response();
521        }
522    };
523
524    // --- session tracking: push turn, check for escalation, evict stale ---
525    let (session_turn_count, session_escalating, session_log_info) = {
526        let mut sessions = state.sessions.lock().unwrap_or_else(|e| e.into_inner());
527        let now = Instant::now();
528        // Lazy TTL: evict only the accessed session if it has expired.
529        // Full map cleanup runs in the background sweeper — no O(N) walk per request.
530        if let Some(h) = sessions.get(&session_id) {
531            if now.duration_since(h.last_seen) >= SESSION_TTL {
532                sessions.remove(&session_id);
533            }
534        }
535        // Refuse new sessions beyond the cap to prevent memory DoS.
536        let is_new = !sessions.contains_key(&session_id);
537        if is_new && sessions.len() >= SESSION_MAX_COUNT {
538            drop(sessions);
539            Metrics::inc(&state.metrics.requests_error_total);
540            let body = ErrorBody {
541                error: ErrorDetail {
542                    message: "session capacity reached — retry later".into(),
543                    kind: "capacity_error".into(),
544                },
545            };
546            return (
547                StatusCode::SERVICE_UNAVAILABLE,
548                HeaderMap::new(),
549                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({}))),
550            )
551                .into_response();
552        }
553        let hist = sessions.entry(session_id.clone()).or_insert_with(SessionHistory::new);
554        hist.push(&result.telemetry.intent_matrix.manipulation_risk);
555        let escalating = hist.is_escalating();
556        let summary = if escalating {
557            Some(hist.risk_summary())
558        } else {
559            None
560        };
561        (hist.turn_count(), escalating, summary)
562    };
563
564    // --- write session log entry on escalation ---
565    if session_escalating {
566        Metrics::inc(&state.metrics.escalations_total);
567        if let (Some(ref log_path), Some((trajectory, historical_mean))) =
568            (&state.session_log_path, session_log_info)
569        {
570            let entry = session_log::SessionLogEntry::new(
571                session_id.clone(),
572                session_turn_count,
573                trajectory,
574                historical_mean,
575                &ip,
576                user_input,
577            );
578            if let Err(e) = session_log::append(log_path, &entry) {
579                eprintln!("sbh serve: session log write error: {e}");
580            }
581        }
582    }
583
584    // --- build response ---
585    let telemetry_json = serde_json::to_string(&result).unwrap_or_else(|_| "{}".into());
586    let content = format!(
587        "{}\n\n<!-- sbh-telemetry: {} -->",
588        summarize_result(&result),
589        telemetry_json,
590    );
591
592    let model_name = req.model.as_deref().unwrap_or(&cfg.model_name).to_string();
593
594    let response_body = ChatResponse {
595        id: format!("sbh-{}", monotonic_id()),
596        object: "chat.completion".into(),
597        created: unix_now(),
598        model: model_name,
599        choices: vec![ChatChoice {
600            index: 0,
601            message: ChatMessage {
602                role: "assistant".into(),
603                content,
604            },
605            finish_reason: "stop".into(),
606        }],
607        usage: Usage {
608            prompt_tokens: (user_input.len() / 4) as u32,
609            completion_tokens: (telemetry_json.len() / 4) as u32,
610            total_tokens: ((user_input.len() + telemetry_json.len()) / 4) as u32,
611        },
612    };
613
614    let mut resp_headers = HeaderMap::new();
615    if let Ok(encoded) = url_encode(&telemetry_json) {
616        if let Ok(val) = HeaderValue::from_str(&encoded) {
617            resp_headers.insert("x-sbh-telemetry", val);
618        }
619    }
620    resp_headers.insert(
621        "x-sbh-version",
622        HeaderValue::from_static(env!("CARGO_PKG_VERSION")),
623    );
624    // Witness status is refreshed every 30s by a background task — zero blocking here.
625    let witness_status = witness_status_str(
626        state.witness_status.load(std::sync::atomic::Ordering::Relaxed),
627    );
628    if let Ok(val) = HeaderValue::from_str(witness_status) {
629        resp_headers.insert("x-sbh-witness", val);
630    }
631    // Session headers
632    if let Ok(val) = HeaderValue::from_str(&session_id) {
633        resp_headers.insert("x-sbh-session", val);
634    }
635    if let Ok(val) = HeaderValue::from_str(&session_turn_count.to_string()) {
636        resp_headers.insert("x-sbh-session-turns", val);
637    }
638    if session_escalating {
639        resp_headers.insert(
640            "x-sbh-session-alert",
641            HeaderValue::from_static("escalation_detected"),
642        );
643    }
644
645    Metrics::inc(&state.metrics.requests_ok_total);
646    (
647        StatusCode::OK,
648        resp_headers,
649        Json(serde_json::to_value(response_body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
650    )
651        .into_response()
652}
653
654// ---------------------------------------------------------------------------
655// Metrics endpoint — Prometheus text exposition format
656// ---------------------------------------------------------------------------
657
658async fn metrics_handler(
659    State(state): State<ServeState>,
660    headers: HeaderMap,
661) -> impl IntoResponse {
662    // /metrics is protected by the same bearer key as the main endpoint.
663    // Without this, an unauthenticated observer can read request rates,
664    // escalation counts, and active session count.
665    if let Some(sk) = &state.config.serve_key {
666        let provided = headers
667            .get("authorization")
668            .and_then(|v| v.to_str().ok())
669            .map(|s| s.trim_start_matches("Bearer ").trim().to_string())
670            .unwrap_or_default();
671        if &provided != sk {
672            return (
673                StatusCode::UNAUTHORIZED,
674                [("content-type", "text/plain; charset=utf-8")],
675                "Unauthorized".to_string(),
676            );
677        }
678    }
679
680    let active_sessions = state.sessions.lock().unwrap_or_else(|e| e.into_inner()).len();
681    let uptime_secs = state.start_time.elapsed().as_secs();
682    let body = state.metrics.render(active_sessions, uptime_secs);
683    (
684        StatusCode::OK,
685        [("content-type", "text/plain; version=0.0.4; charset=utf-8")],
686        body,
687    )
688}
689
690// ---------------------------------------------------------------------------
691// Health check
692// ---------------------------------------------------------------------------
693
694async fn health() -> impl IntoResponse {
695    Json(serde_json::json!({
696        "status": "ok",
697        "version": env!("CARGO_PKG_VERSION"),
698        "service": "split-brain-harness"
699    }))
700}
701
702// ---------------------------------------------------------------------------
703// Public entry point
704// ---------------------------------------------------------------------------
705
706pub async fn run_server(listen: &str, config: Config, tls_cert: Option<&str>, tls_key: Option<&str>) -> anyhow::Result<()> {
707    let rate_limit = config.serve_rate_limit;
708    let max_body = config.serve_max_body_bytes;
709    let auth_enabled = config.serve_key.is_some();
710    let session_log_path = config.session_log_path.clone();
711    let context_path = config.context_path.clone();
712
713    let witness_cache = Arc::new(std::sync::atomic::AtomicU8::new(WITNESS_UNCONFIGURED));
714    if config.audit_path.is_some() {
715        spawn_witness_poller(Arc::clone(&witness_cache));
716    }
717
718    let sessions: Arc<Mutex<HashMap<String, SessionHistory>>> =
719        Arc::new(Mutex::new(HashMap::new()));
720
721    // Background task: sweep expired sessions every SESSION_SWEEP_INTERVAL.
722    // The hot path no longer calls retain() — this is the only full-map walk.
723    {
724        let sessions_sweep = Arc::clone(&sessions);
725        tokio::spawn(async move {
726            loop {
727                tokio::time::sleep(SESSION_SWEEP_INTERVAL).await;
728                let mut map = sessions_sweep.lock().unwrap_or_else(|e| e.into_inner());
729                let now = Instant::now();
730                map.retain(|_, h| now.duration_since(h.last_seen) < SESSION_TTL);
731            }
732        });
733    }
734
735    let state = ServeState {
736        config: Arc::new(config),
737        rate_limiter: Arc::new(ShardedRateLimiter::new()),
738        sessions,
739        session_log_path: session_log_path.clone(),
740        metrics: Arc::new(Metrics::default()),
741        start_time: Arc::new(Instant::now()),
742        witness_status: witness_cache,
743    };
744
745    let app = Router::new()
746        .route("/v1/chat/completions", post(chat_completions))
747        .route("/health", get(health))
748        .route("/metrics", get(metrics_handler))
749        .layer(DefaultBodyLimit::max(max_body))
750        .with_state(state);
751
752    let print_banner = |scheme: &str, addr: SocketAddr| {
753        eprintln!("sbh serve: listening on {scheme}://{addr}");
754        eprintln!("  POST /v1/chat/completions  — OpenAI-compatible harness proxy");
755        eprintln!("  GET  /health               — liveness check");
756        eprintln!("  GET  /metrics              — Prometheus counters");
757        eprintln!(
758            "  auth: {}  rate: {}/min/IP  max-body: {} bytes",
759            if auth_enabled { "enabled" } else { "disabled" },
760            rate_limit,
761            max_body,
762        );
763        match &session_log_path {
764            Some(p) => eprintln!("  session log: {p}"),
765            None => eprintln!("  session log: disabled (set SBH_SESSION_LOG or --session-log)"),
766        };
767        {
768            use crate::rag::ContextCorpus;
769            let embedded_count = ContextCorpus::embedded().len();
770            match context_path.as_deref() {
771                None => eprintln!("  context: {embedded_count} embedded docs (set SBH_CONTEXT_PATH to add operator docs)"),
772                Some(p) => match ContextCorpus::load(p) {
773                    Ok(extra) => eprintln!("  context: {} embedded + {} operator docs from {p}", embedded_count, extra.len()),
774                    Err(e) => eprintln!("  context: {p} load error — {e}"),
775                },
776            }
777        }
778    };
779
780    match (tls_cert, tls_key) {
781        (Some(cert), Some(key)) => {
782            use axum_server::tls_rustls::RustlsConfig;
783            let tls_config = RustlsConfig::from_pem_file(cert, key)
784                .await
785                .with_context(|| format!("TLS: failed to load cert={cert} key={key}"))?;
786            let addr: SocketAddr = listen.parse()
787                .with_context(|| format!("invalid listen address: {listen}"))?;
788            print_banner("https", addr);
789            axum_server::bind_rustls(addr, tls_config)
790                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
791                .await?;
792        }
793        (Some(_), None) => anyhow::bail!("--tls-cert requires --tls-key"),
794        (None, Some(_)) => anyhow::bail!("--tls-key requires --tls-cert"),
795        (None, None) => {
796            let listener = tokio::net::TcpListener::bind(listen).await?;
797            let addr = listener.local_addr()?;
798            print_banner("http", addr);
799            axum::serve(
800                listener,
801                app.into_make_service_with_connect_info::<SocketAddr>(),
802            )
803            .await?;
804        }
805    }
806    Ok(())
807}
808
809// ---------------------------------------------------------------------------
810// Helpers
811// ---------------------------------------------------------------------------
812
813fn summarize_result(result: &crate::types::HarnessResult) -> String {
814    let t = &result.telemetry;
815    let v = &result.verification;
816    format!(
817        "[SBH Analysis]\nEmotion: {} (intensity {:.2})\nManipulation risk: {}\nCoherence: {:.2}\nVerification: {} (confidence {:.2}){}",
818        t.affective_telemetry.primary_emotion,
819        t.affective_telemetry.emotional_intensity,
820        t.intent_matrix.manipulation_risk,
821        t.cognitive_state.coherence_rating,
822        if v.passed { "passed" } else { "flagged" },
823        v.confidence,
824        if v.stop_and_ask {
825            "\n⚠ stop_and_ask=true — low confidence, review before acting"
826        } else {
827            ""
828        },
829    )
830}
831
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before 1970-01-01.
fn unix_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock set before the epoch: degrade to 0 instead of failing.
        Err(_) => 0,
    }
}
838
/// Process-wide, strictly increasing identifier; the first call returns 1.
///
/// Backed by a single atomic counter, so concurrent callers each receive a
/// distinct value. `Relaxed` ordering suffices because only uniqueness of
/// the returned number matters, not ordering relative to other memory.
fn monotonic_id() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Next value to hand out.
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);

    let issued = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    issued
}
844
/// Generate an unguessable session ID: `sbh-` followed by 32 lowercase hex
/// digits (128 bits).
///
/// Entropy is read from `/dev/urandom` when it can be opened (Linux, macOS and
/// other Unix-likes). If that fails — e.g. on Windows or in a sandbox without
/// `/dev` — the bits come from [`session_fallback_entropy`]. The ID therefore
/// has the same shape on every platform and is never a plain
/// counter/timestamp that a client could enumerate to reach other sessions.
fn mint_session_id() -> String {
    let bytes = session_os_entropy().unwrap_or_else(session_fallback_entropy);
    encode_session_id(&bytes)
}

/// Read 16 bytes of OS randomness from `/dev/urandom`; `None` if unavailable.
fn session_os_entropy() -> Option<[u8; 16]> {
    use std::io::Read;

    let mut buf = [0u8; 16];
    std::fs::File::open("/dev/urandom")
        .and_then(|mut f| f.read_exact(&mut buf))
        .ok()?;
    Some(buf)
}

/// Best-effort 128 bits for platforms without `/dev/urandom`.
///
/// Each 8-byte half is the output of std's `DefaultHasher`, keyed by a fresh
/// `RandomState` (which std documents as initialized with random keys), over
/// a per-call nonce and the wall-clock nanoseconds. This is not a vetted
/// CSPRNG. However, a remote client that never sees the hasher keys cannot
/// predict the result, unlike the counter/timestamp it replaces.
fn session_fallback_entropy() -> [u8; 16] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Distinguishes calls that land on the same clock tick.
    static NONCE: AtomicU64 = AtomicU64::new(0);
    let nonce = NONCE.fetch_add(1, Ordering::Relaxed);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());

    let mut out = [0u8; 16];
    for (lane, chunk) in out.chunks_exact_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        (nonce, nanos, lane).hash(&mut hasher);
        chunk.copy_from_slice(&hasher.finish().to_le_bytes());
    }
    out
}

/// Render 16 bytes as `sbh-` plus four zero-padded 8-digit hex words, each
/// word read little-endian from consecutive 4-byte groups.
fn encode_session_id(bytes: &[u8; 16]) -> String {
    let word = |at: usize| {
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    };
    format!(
        "sbh-{:08x}{:08x}{:08x}{:08x}",
        word(0),
        word(4),
        word(8),
        word(12),
    )
}
865
/// Percent-encode a string for use in HTTP header values.
///
/// Works on UTF-8 bytes, not Unicode codepoints, so multibyte chars like `é`
/// (UTF-8: 0xC3 0xA9) become `%C3%A9`, not `%E9`.
///
/// Bytes passed through unchanged:
/// - ASCII letters and digits;
/// - the RFC 3986 unreserved marks `-` `_` `.` `~`;
/// - `:` `/` `,` `[` `]` `{` `}`.
///
/// The last group is not RFC 3986 "unreserved". Keeping it literal keeps JSON
/// such as the telemetry payload compact and readable, and all of those
/// characters are visible ASCII, which header values accept. Every other byte
/// is written as `%XX` with uppercase hex digits. That includes `%` itself,
/// space, quotes, control characters and all non-ASCII bytes.
///
/// # Errors
///
/// Never fails. The `Result` is part of the existing signature so callers
/// using `?`/`unwrap` keep compiling.
fn url_encode(s: &str) -> Result<String, ()> {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        let passthrough = matches!(
            byte,
            b'A'..=b'Z'
                | b'a'..=b'z'
                | b'0'..=b'9'
                | b'-' | b'_' | b'.' | b'~'
                | b':' | b'/' | b',' | b'[' | b']' | b'{' | b'}'
        );
        if passthrough {
            out.push(char::from(byte));
        } else {
            // Push the escape directly instead of allocating via `format!`
            // for every encoded byte.
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    Ok(out)
}
885
886// ---------------------------------------------------------------------------
887// Tests
888// ---------------------------------------------------------------------------
889
#[cfg(test)]
mod tests {
    use super::*;

    // --- metrics ---

    #[test]
    fn metrics_render_contains_all_metric_names() {
        let m = Metrics::default();
        let out = m.render(0, 0);
        for name in &[
            "sbh_requests_total",
            "sbh_requests_ok_total",
            "sbh_requests_error_total",
            "sbh_auth_failures_total",
            "sbh_rate_limit_total",
            "sbh_escalations_total",
            "sbh_active_sessions",
            "sbh_uptime_seconds",
        ] {
            assert!(out.contains(name), "missing metric: {name}");
        }
    }

    #[test]
    fn metrics_render_prometheus_format() {
        let m = Metrics::default();
        let out = m.render(3, 42);
        assert!(out.contains("# HELP sbh_requests_total"));
        assert!(out.contains("# TYPE sbh_requests_total counter"));
        assert!(out.contains("sbh_requests_total 0\n"));
        assert!(out.contains("sbh_active_sessions 3\n"));
        assert!(out.contains("sbh_uptime_seconds 42\n"));
    }

    #[test]
    fn metrics_counters_increment_correctly() {
        let m = Metrics::default();
        Metrics::inc(&m.requests_total);
        Metrics::inc(&m.requests_total);
        Metrics::inc(&m.escalations_total);
        let out = m.render(0, 0);
        assert!(out.contains("sbh_requests_total 2\n"));
        assert!(out.contains("sbh_escalations_total 1\n"));
        assert!(out.contains("sbh_requests_ok_total 0\n"));
    }

    #[test]
    fn metrics_render_has_help_and_type_for_every_metric() {
        let m = Metrics::default();
        let out = m.render(0, 0);
        let help_count = out.lines().filter(|l| l.starts_with("# HELP")).count();
        let type_count = out.lines().filter(|l| l.starts_with("# TYPE")).count();
        assert_eq!(help_count, 8, "expected 8 # HELP lines");
        assert_eq!(type_count, 8, "expected 8 # TYPE lines");
    }

    // --- url_encode ---

    #[test]
    fn url_encode_spaces_and_quotes() {
        let s = r#"{"key": "val ue"}"#;
        let encoded = url_encode(s).unwrap();
        assert!(!encoded.contains(' '));
        assert!(!encoded.contains('"'));
        assert!(encoded.contains("%20"));
        assert!(encoded.contains("%22"));
    }

    #[test]
    fn url_encode_clean_string_unchanged() {
        let s = "hello-world_123";
        assert_eq!(url_encode(s).unwrap(), s);
    }

    #[test]
    fn url_encode_multibyte_utf8_is_encoded_per_byte() {
        // `é` is 0xC3 0xA9 in UTF-8; each byte gets its own escape.
        assert_eq!(url_encode("é").unwrap(), "%C3%A9");
    }

    #[test]
    fn url_encode_escapes_percent_and_control_chars() {
        assert_eq!(url_encode("100%\n").unwrap(), "100%25%0A");
    }

    #[test]
    fn url_encode_passes_json_structural_chars() {
        assert_eq!(url_encode("a:b/c,[d]{e}~").unwrap(), "a:b/c,[d]{e}~");
    }

    // --- misc helpers ---

    #[test]
    fn unix_now_is_nonzero() {
        assert!(unix_now() > 0);
    }

    #[test]
    fn monotonic_id_increases() {
        let a = monotonic_id();
        let b = monotonic_id();
        assert!(b > a);
    }

    #[test]
    fn mint_session_id_has_prefix_and_is_unique() {
        let a = mint_session_id();
        let b = mint_session_id();
        assert!(a.starts_with("sbh-"), "unexpected id: {a}");
        assert_ne!(a, b);
    }

    // --- session escalation tracking ---

    #[test]
    fn session_no_escalation_below_three_turns() {
        let mut h = SessionHistory::new();
        h.push("high");
        h.push("high");
        assert!(!h.is_escalating(), "need ≥3 turns before firing");
    }

    #[test]
    fn session_escalation_detected_on_slow_boil() {
        let mut h = SessionHistory::new();
        h.push("low");
        h.push("low");
        h.push("high");
        assert!(h.is_escalating(), "low→low→high is slow-boil escalation");
    }

    #[test]
    fn session_no_escalation_when_already_high() {
        let mut h = SessionHistory::new();
        h.push("high");
        h.push("high");
        h.push("high");
        // All turns already high — no upward delta
        assert!(!h.is_escalating());
    }

    #[test]
    fn session_no_escalation_medium_to_medium() {
        let mut h = SessionHistory::new();
        h.push("low");
        h.push("medium");
        h.push("medium");
        // medium is 1.0; historical mean 0.5 → delta 0.5, but not > 0.5
        assert!(!h.is_escalating());
    }

    #[test]
    fn session_escalation_low_to_high_five_turns() {
        let mut h = SessionHistory::new();
        for _ in 0..4 {
            h.push("low");
        }
        h.push("high");
        assert!(h.is_escalating());
    }

    #[test]
    fn session_ring_caps_at_max_turns() {
        let mut h = SessionHistory::new();
        for _ in 0..SESSION_MAX_TURNS + 5 {
            h.push("low");
        }
        assert_eq!(h.turn_count(), SESSION_MAX_TURNS);
    }

    #[test]
    fn risk_score_mapping() {
        assert_eq!(risk_score("low"), 0.0);
        assert_eq!(risk_score("medium"), 1.0);
        assert_eq!(risk_score("high"), 2.0);
        assert_eq!(risk_score("unknown"), 0.0);
    }

    // --- rate limiting ---

    #[test]
    fn rate_limit_allows_up_to_max() {
        let limiter = ShardedRateLimiter::new();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        for _ in 0..5 {
            assert!(limiter.check(ip, 5));
        }
        assert!(!limiter.check(ip, 5));
    }

    #[test]
    fn rate_limit_different_ips_are_independent() {
        let limiter = ShardedRateLimiter::new();
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        for _ in 0..3 {
            assert!(limiter.check(ip1, 3));
        }
        assert!(!limiter.check(ip1, 3));
        assert!(limiter.check(ip2, 3));
    }

    // --- summarize_result ---

    #[test]
    fn summarize_result_contains_key_fields() {
        use crate::types::*;
        let result = HarnessResult {
            telemetry: TelemetryResult {
                affective_telemetry: AfferentTelemetry {
                    primary_emotion: "neutral".into(),
                    emotional_intensity: 0.1,
                    structural_tone: vec!["analytical".into()],
                },
                intent_matrix: IntentMatrix {
                    stated_objective: "test query".into(),
                    subtextual_motive: "none".into(),
                    manipulation_risk: "low".into(),
                },
                cognitive_state: CognitiveState {
                    urgency_vector: 0.0,
                    coherence_rating: 0.9,
                },
            },
            verification: VerificationReport {
                passed: true,
                consistency_flags: vec![],
                unsupported_claims: vec![],
                assumptions: vec![],
                unresolved: vec![],
                confidence: 0.9,
                stop_and_ask: false,
            },
            trace: vec![],
            capability_request: None,
            obfuscation: None,
        };
        let s = summarize_result(&result);
        // Pin each rendered field, not just loose substrings such as "low".
        assert!(s.starts_with("[SBH Analysis]\n"), "unexpected summary: {s}");
        assert!(s.contains("Emotion: neutral (intensity 0.10)"), "unexpected summary: {s}");
        assert!(s.contains("Manipulation risk: low"), "unexpected summary: {s}");
        assert!(s.contains("Coherence: 0.90"), "unexpected summary: {s}");
        assert!(s.contains("Verification: passed (confidence 0.90)"), "unexpected summary: {s}");
        assert!(
            !s.contains("stop_and_ask"),
            "warning must be absent when stop_and_ask=false: {s}"
        );
    }
}