use crate::log_store::LogRecord;
use crate::state::*;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use std::collections::HashMap;
/// Builds the router for the client-sync endpoints.
///
/// Two routes are registered:
/// - `POST /api/sync`, served by `sync`
/// - `GET /api/connected_clients`, served by `connected_clients`
pub fn routes() -> axum::Router<AppState> {
    use axum::routing::{get, post};

    let router = axum::Router::new();
    let router = router.route("/api/sync", post(sync));
    router.route("/api/connected_clients", get(connected_clients))
}
/// Checks the `X-Api-Key` request header against the key store.
///
/// # Returns
/// - `Ok(None)` when the key store reports it is not enabled (no check is made).
/// - `Ok(Some(name))` when `validate` accepts the provided key; `name` is
///   whatever `validate` returned for it.
///
/// # Errors
/// `StatusCode::UNAUTHORIZED` when the store is enabled and `validate` rejects
/// the key. A header that is missing, or whose value `to_str` cannot read, is
/// validated as the empty string.
///
/// # Panics
/// Panics if the `key_store` lock is poisoned.
fn check_auth(state: &AppState, headers: &HeaderMap) -> Result<Option<String>, StatusCode> {
    let key_store = state.key_store.read().unwrap();
    if key_store.is_enabled() {
        // Missing and unreadable headers both collapse to "" so that the
        // store's own validation decides the outcome.
        let api_key = match headers.get("X-Api-Key") {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        key_store
            .validate(api_key)
            .map(Some)
            .ok_or(StatusCode::UNAUTHORIZED)
    } else {
        Ok(None)
    }
}
/// One correction carried in the `corrections` list of a `SyncBatchRequest`.
///
/// `sync` looks `namespace` up on the engine and, when it exists, passes
/// `query`, `wrong_intent` and `right_intent` (in that order) to the
/// resolver's `correct` call. All four fields are required when deserializing.
#[derive(serde::Deserialize)]
pub struct BatchCorrection {
/// Namespace whose resolver receives the correction. `sync` skips the
/// correction when the engine has no such namespace.
pub namespace: String,
/// Query text the correction applies to (first argument to `correct`).
pub query: String,
/// Second argument to `correct`; presumably the intent that was wrongly
/// detected for `query` (confirm against the resolver).
pub wrong_intent: String,
/// Third argument to `correct`; presumably the intent that should have been
/// detected instead (confirm against the resolver).
pub right_intent: String,
}
/// One log entry carried in the `logs` list of a `SyncBatchRequest`.
///
/// `sync` copies every field into a `LogRecord` (adding `id: 0` and a
/// `source` tag) and appends it to the log store. `query`, `app_id` and
/// `timestamp_ms` are required; the remaining fields fall back to defaults
/// when absent from the payload.
#[derive(serde::Deserialize)]
pub struct BatchLogEntry {
/// Query text that was logged.
pub query: String,
/// Identifier of the reporting application (copied to `LogRecord::app_id`).
pub app_id: String,
/// Optional session identifier; `None` when omitted.
#[serde(default)]
pub session_id: Option<String>,
/// Intents detected for the query; empty when omitted.
#[serde(default)]
pub detected_intents: Vec<String>,
/// Confidence label; `"none"` (from `default_confidence`) when omitted.
#[serde(default = "default_confidence")]
pub confidence: String,
/// Timestamp of the entry in milliseconds. NOTE(review): presumably
/// client-side Unix epoch milliseconds; confirm with the client library.
pub timestamp_ms: u64,
/// Router version the client reported for this entry; `0` when omitted.
#[serde(default)]
pub router_version: u64,
}
/// Serde default for `BatchLogEntry::confidence`: the literal `"none"`.
fn default_confidence() -> String {
    String::from("none")
}
/// JSON body accepted by `POST /api/sync`.
///
/// Every field is optional in the payload: maps and lists default to empty,
/// optional scalars default to `None`.
#[derive(serde::Deserialize)]
pub struct SyncBatchRequest {
/// Namespace id -> version the client currently holds. `sync` compares each
/// value with the server's version and also records the keys as the
/// client's subscribed namespaces.
#[serde(default)]
pub local_versions: HashMap<String, u64>,
/// Log entries to append to the server's log store.
#[serde(default)]
pub logs: Vec<BatchLogEntry>,
/// Corrections to apply to the matching namespaces' resolvers.
#[serde(default)]
pub corrections: Vec<BatchCorrection>,
/// Tick interval the client reports, in seconds. `sync` records 30 when this
/// is absent; `connected_clients` drops a client once it has been silent for
/// more than twice this interval.
#[serde(default)]
pub tick_interval_secs: Option<u32>,
/// Client library version string, stored on the connected-client record.
#[serde(default)]
pub library_version: Option<String>,
}
/// Handles `POST /api/sync`: a client uploads logs and corrections and learns
/// which of its namespaces are out of date.
///
/// Steps, in order:
/// 1. Authenticate with `check_auth`.
/// 2. If a key name was returned, insert or replace that client's entry in
///    `state.connected_clients` (keyed by the key name) with its sorted
///    namespace list, tick interval, library version and the current time.
/// 3. Apply each correction whose namespace the engine knows; after every
///    correction for which `correct` returns `Ok`, call `maybe_commit` for
///    that namespace.
/// 4. Append every log entry to the log store, tagged with
///    `"connected:<name>"` (or `"connected"` when there is no key name), then
///    notify the worker once.
/// 5. For every namespace in `local_versions`, report whether the client is up
///    to date and include an export when it is not.
///
/// # Response
/// ```text
/// {
///   "namespaces": { "<ns>": { "up_to_date": bool, "version": n, "export"?: ... } },
///   "logs_accepted": <number of log entries in the request>,
///   "corrections_applied": <number of corrections for which `correct` returned Ok>
/// }
/// ```
/// A namespace the engine does not know is reported as
/// `{"up_to_date": true, "version": 0}`.
///
/// # Errors
/// Returns `StatusCode::UNAUTHORIZED` when `check_auth` rejects the request.
///
/// # Panics
/// Panics if the `connected_clients` or `log_store` lock is poisoned.
pub async fn sync(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<SyncBatchRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Tick interval recorded for clients that do not report one.
    const DEFAULT_TICK_INTERVAL_SECS: u32 = 30;

    let key_name = check_auth(&state, &headers)?;

    // Only callers with a key name are tracked as connected clients.
    if let Some(ref name) = key_name {
        let mut subscribed: Vec<String> = req.local_versions.keys().cloned().collect();
        // Map keys are unique, so an unstable sort yields the same order.
        subscribed.sort_unstable();
        let entry = ConnectedClient {
            name: name.clone(),
            namespaces: subscribed,
            tick_interval_secs: req
                .tick_interval_secs
                .unwrap_or(DEFAULT_TICK_INTERVAL_SECS),
            library_version: req.library_version.clone(),
            last_seen_ms: now_ms(),
        };
        state
            .connected_clients
            .write()
            .unwrap()
            .insert(name.clone(), entry);
    }

    // Tag stored on every log record accepted from this request.
    let source = match key_name {
        Some(name) => format!("connected:{}", name),
        None => "connected".to_string(),
    };

    let mut corrections_applied: usize = 0;
    for correction in &req.corrections {
        // Corrections for namespaces the engine does not know are skipped.
        if let Some(h) = state.engine.try_namespace(&correction.namespace) {
            let outcome = h.with_resolver_mut(|r| {
                r.correct(
                    &correction.query,
                    &correction.wrong_intent,
                    &correction.right_intent,
                )
            });
            if outcome.is_ok() {
                corrections_applied += 1;
                maybe_commit(&state, &correction.namespace);
            }
        }
    }

    let logs_accepted = req.logs.len();
    if !req.logs.is_empty() {
        let mut store = state.log_store.lock().unwrap();
        for entry in req.logs {
            let record = LogRecord {
                // NOTE(review): 0 looks like a placeholder the store replaces
                // on append; confirm against `LogRecord` / the log store.
                id: 0,
                query: entry.query,
                app_id: entry.app_id,
                detected_intents: entry.detected_intents,
                confidence: entry.confidence,
                session_id: entry.session_id,
                timestamp_ms: entry.timestamp_ms,
                router_version: entry.router_version,
                source: source.clone(),
            };
            store.append(record);
        }
        // Release the log-store lock before waking the worker.
        drop(store);
        state.worker_notify.notify_one();
    }

    let mut namespaces = serde_json::Map::new();
    for (ns_id, local_version) in &req.local_versions {
        let entry = match state.engine.try_namespace(ns_id) {
            None => serde_json::json!({"up_to_date": true, "version": 0}),
            Some(h) => {
                let client_version = *local_version;
                // Read the version and, when needed, the export inside ONE
                // resolver access. With two separate `with_resolver` calls a
                // correction from another request could land in between, and
                // the response would pair an older version with a newer export.
                // This assumes `with_resolver` presents a consistent view for
                // the duration of the closure.
                let (server_version, export) = h.with_resolver(move |r| {
                    let version = r.version();
                    let export = if version == client_version {
                        None
                    } else {
                        Some(r.export_json())
                    };
                    (version, export)
                });
                match export {
                    None => serde_json::json!({
                        "up_to_date": true,
                        "version": server_version,
                    }),
                    Some(export) => serde_json::json!({
                        "up_to_date": false,
                        "version": server_version,
                        "export": export,
                    }),
                }
            }
        };
        namespaces.insert(ns_id.clone(), entry);
    }

    Ok(Json(serde_json::json!({
        "namespaces": namespaces,
        "logs_accepted": logs_accepted,
        "corrections_applied": corrections_applied,
    })))
}
/// Handles `GET /api/connected_clients`: lists the clients that have synced
/// recently.
///
/// Before building the response, every client whose last sync is older than
/// twice its tick interval is removed from `state.connected_clients`, so this
/// endpoint prunes the registry as a side effect.
///
/// # Response
/// ```text
/// { "count": <number of clients>, "clients": [ { "name", "namespaces",
///   "tick_interval_secs", "library_version", "last_seen_ms", "age_ms",
///   "expires_in_ms" }, ... ] }
/// ```
/// `age_ms` is the time since the client's last sync and `expires_in_ms` is
/// how long until it would be pruned, both in milliseconds.
///
/// # Panics
/// Panics if the `connected_clients` lock is poisoned.
pub async fn connected_clients(State(state): State<AppState>) -> Json<serde_json::Value> {
    /// How long a client may stay silent before it is dropped: two ticks.
    fn stale_window_ms(tick_interval_secs: u32) -> u64 {
        u64::from(tick_interval_secs) * 2_000
    }

    let now = now_ms();
    let mut registry = state.connected_clients.write().unwrap();

    // Drop clients that have been silent for longer than their stale window.
    registry.retain(|_, client| {
        now.saturating_sub(client.last_seen_ms) <= stale_window_ms(client.tick_interval_secs)
    });

    let mut items: Vec<serde_json::Value> = Vec::with_capacity(registry.len());
    for client in registry.values() {
        let age_ms = now.saturating_sub(client.last_seen_ms);
        let expires_in_ms = stale_window_ms(client.tick_interval_secs).saturating_sub(age_ms);
        items.push(serde_json::json!({
            "name": client.name,
            "namespaces": client.namespaces,
            "tick_interval_secs": client.tick_interval_secs,
            "library_version": client.library_version,
            "last_seen_ms": client.last_seen_ms,
            "age_ms": age_ms,
            "expires_in_ms": expires_in_ms,
        }));
    }

    Json(serde_json::json!({
        "count": items.len(),
        "clients": items,
    }))
}