use crate::log_store::LogRecord;
use crate::state::*;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use std::collections::HashMap;
/// Builds the router for this module.
///
/// Registers a single route: `POST /api/sync`, served by [`sync`].
pub fn routes() -> axum::Router<AppState> {
    let sync_handler = axum::routing::post(sync);
    let router = axum::Router::new();
    router.route("/api/sync", sync_handler)
}
/// Authenticates a request against the application's key store.
///
/// # Returns
/// * `Ok(None)` when the key store reports that it is not enabled.
/// * `Ok(Some(name))` when the `X-Api-Key` header value is accepted by the
///   store's `validate`; `name` is whatever `validate` returned for it.
/// * `Err(StatusCode::UNAUTHORIZED)` when `validate` rejects the value.
///
/// A header that is missing, or whose value cannot be converted with
/// `to_str`, is validated as the empty string.
///
/// # Panics
/// Panics if acquiring the read guard on `state.key_store` fails
/// (presumably a poisoned lock — confirm against the `AppState` definition).
fn check_auth(state: &AppState, headers: &HeaderMap) -> Result<Option<String>, StatusCode> {
    let key_store = state.key_store.read().unwrap();

    if key_store.is_enabled() {
        // Absent or unreadable header values fall back to "" and are still
        // passed through `validate`, exactly like any other candidate key.
        let api_key = match headers.get("X-Api-Key") {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        key_store
            .validate(api_key)
            .map(Some)
            .ok_or(StatusCode::UNAUTHORIZED)
    } else {
        Ok(None)
    }
}
/// One intent correction submitted as part of a sync batch.
///
/// The `sync` handler looks up `namespace` on the engine and, if it exists,
/// passes `query`, `wrong_intent` and `right_intent` to that namespace's
/// resolver `correct` call. All four fields are required when deserializing.
#[derive(serde::Deserialize)]
pub struct BatchCorrection {
/// Identifier of the namespace whose resolver receives the correction.
pub namespace: String,
/// Query text the correction refers to.
pub query: String,
/// Intent that was wrongly associated with `query`.
pub wrong_intent: String,
/// Intent that should be associated with `query` instead.
pub right_intent: String,
}
/// One query-log entry submitted as part of a sync batch.
///
/// The `sync` handler copies every field into a `LogRecord` and appends it to
/// the log store. `query`, `app_id` and `timestamp_ms` are required; the other
/// fields fall back to the defaults noted below when absent from the payload.
#[derive(serde::Deserialize)]
pub struct BatchLogEntry {
/// Query text that was routed.
pub query: String,
/// Identifier of the application that produced the entry.
pub app_id: String,
/// Optional session identifier; `None` when absent.
#[serde(default)]
pub session_id: Option<String>,
/// Intents reported for the query; empty when absent.
#[serde(default)]
pub detected_intents: Vec<String>,
/// Confidence label; `"none"` (see `default_confidence`) when absent.
#[serde(default = "default_confidence")]
pub confidence: String,
/// Entry timestamp. NOTE(review): presumably milliseconds since the Unix
/// epoch, judging by the name — confirm with the client implementation.
pub timestamp_ms: u64,
/// Router version reported by the client; `0` when absent.
#[serde(default)]
pub router_version: u64,
}
/// Serde default for `BatchLogEntry::confidence`: the literal label `"none"`.
fn default_confidence() -> String {
    String::from("none")
}
/// JSON body accepted by the `sync` handler.
///
/// Every field is optional in the payload and defaults to an empty
/// collection, so an empty JSON object is a valid request.
#[derive(serde::Deserialize)]
pub struct SyncBatchRequest {
/// Namespace id mapped to the version the client currently holds; the
/// handler compares each against the server-side resolver version.
#[serde(default)]
pub local_versions: HashMap<String, u64>,
/// Log entries to append to the server's log store.
#[serde(default)]
pub logs: Vec<BatchLogEntry>,
/// Corrections to apply to server-side resolvers.
#[serde(default)]
pub corrections: Vec<BatchCorrection>,
}
/// Handles `POST /api/sync`: applies client corrections, ingests client logs
/// and reports, per requested namespace, whether the client is up to date.
///
/// Steps, in order:
/// 1. Authenticate via `check_auth`; a rejected key short-circuits with the
///    returned status code.
/// 2. Apply each correction to the resolver of its namespace. Corrections for
///    unknown namespaces, or whose `correct` call returns `Err`, are skipped
///    and not counted. After every applied correction `maybe_commit` is called
///    for that namespace.
/// 3. Append every log entry to the log store, tagged with a `source` of
///    `"connected"` or `"connected:<key name>"`, then notify the worker once.
/// 4. For each entry in `local_versions`, compare the client's version with
///    the server's and include an export when they differ.
///
/// # Returns
/// A JSON object with `namespaces` (per-namespace status), `logs_accepted`
/// (number of log entries in the request) and `corrections_applied`.
///
/// # Errors
/// Returns the `StatusCode` produced by `check_auth` when authentication
/// fails.
///
/// # Panics
/// Panics if locking `state.log_store` fails (presumably a poisoned mutex —
/// confirm against the `AppState` definition).
pub async fn sync(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<SyncBatchRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let key_name = check_auth(&state, &headers)?;
    let source = match key_name {
        Some(name) => format!("connected:{}", name),
        None => "connected".to_string(),
    };

    // --- 1. Corrections -----------------------------------------------------
    let mut corrections_applied: usize = 0;
    for correction in &req.corrections {
        if let Some(h) = state.engine.try_namespace(&correction.namespace) {
            let applied = h
                .with_resolver_mut(|r| {
                    r.correct(
                        &correction.query,
                        &correction.wrong_intent,
                        &correction.right_intent,
                    )
                })
                .is_ok();
            if applied {
                corrections_applied += 1;
                maybe_commit(&state, &correction.namespace);
            }
        }
    }

    // --- 2. Logs ------------------------------------------------------------
    let logs_accepted = req.logs.len();
    if !req.logs.is_empty() {
        let mut store = state.log_store.lock().unwrap();
        for entry in req.logs {
            let record = LogRecord {
                // The handler always submits 0 as the id.
                id: 0,
                query: entry.query,
                app_id: entry.app_id,
                detected_intents: entry.detected_intents,
                confidence: entry.confidence,
                session_id: entry.session_id,
                timestamp_ms: entry.timestamp_ms,
                router_version: entry.router_version,
                source: source.clone(),
            };
            store.append(record);
        }
        // Release the store before waking the worker so it is not notified
        // while this handler still holds the lock.
        drop(store);
        state.worker_notify.notify_one();
    }

    // --- 3. Namespace versions ----------------------------------------------
    let mut namespaces = serde_json::Map::new();
    for (ns_id, local_version) in &req.local_versions {
        let entry = match state.engine.try_namespace(ns_id) {
            // Unknown namespaces are reported as up to date at version 0.
            None => serde_json::json!({"up_to_date": true, "version": 0}),
            Some(h) => {
                let client_version = *local_version;
                // Read the version and (if needed) the export inside ONE
                // `with_resolver` call. Doing it in two calls lets a
                // concurrent correction slip in between, so the export could
                // be newer than the `version` reported alongside it.
                // NOTE(review): this relies on `with_resolver` keeping the
                // resolver stable for the closure's duration — confirm.
                let (server_version, export) = h.with_resolver(|r| {
                    let version = r.version();
                    let export = (version != client_version).then(|| r.export_json());
                    (version, export)
                });
                match export {
                    None => {
                        serde_json::json!({"up_to_date": true, "version": server_version})
                    }
                    Some(export) => serde_json::json!({
                        "up_to_date": false,
                        "version": server_version,
                        "export": export,
                    }),
                }
            }
        };
        namespaces.insert(ns_id.clone(), entry);
    }

    Ok(Json(serde_json::json!({
        "namespaces": namespaces,
        "logs_accepted": logs_accepted,
        "corrections_applied": corrections_applied,
    })))
}