// attuned_http/handlers.rs

1//! HTTP request handlers.
2
3use attuned_core::{
4    HealthCheck, HealthState, HealthStatus, PromptContext, RuleTranslator, Source, StateSnapshot,
5    Translator,
6};
7use attuned_store::StateStore;
8use axum::{
9    extract::{Path, State},
10    http::StatusCode,
11    response::IntoResponse,
12    Json,
13};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use std::time::Instant;
17
18#[cfg(feature = "inference")]
19use attuned_infer::{Baseline, InferenceConfig, InferenceEngine, InferenceSource};
20#[cfg(feature = "inference")]
21use dashmap::DashMap;
22#[cfg(feature = "inference")]
23use std::collections::HashMap;
24
/// Application state shared across handlers.
///
/// Handlers receive this wrapped in an `Arc` (see the `State<Arc<AppState<S>>>`
/// extractors below), so all requests observe the same store, translator,
/// engine, and baselines.
pub struct AppState<S: StateStore> {
    /// The state store backend.
    pub store: Arc<S>,
    /// The translator for converting state to context.
    pub translator: Arc<dyn Translator>,
    /// Server start time for uptime calculation (reported by `/health`).
    pub start_time: Instant,
    /// Inference engine (optional, requires "inference" feature).
    /// `None` means inference endpoints respond with `INFERENCE_DISABLED`.
    #[cfg(feature = "inference")]
    pub inference_engine: Option<InferenceEngine>,
    /// Per-user baselines for delta analysis, keyed by user ID.
    /// Concurrent map: entries are created lazily on first inference.
    #[cfg(feature = "inference")]
    pub baselines: Arc<DashMap<String, Baseline>>,
}
40
impl<S: StateStore> AppState<S> {
    /// Create new application state.
    ///
    /// Uses the default rule-based translator and records the current
    /// instant for uptime reporting. When the "inference" feature is
    /// compiled in, the engine starts disabled (`None`) and the baseline
    /// map starts empty.
    pub fn new(store: S) -> Self {
        Self {
            store: Arc::new(store),
            translator: Arc::new(RuleTranslator::default()),
            start_time: Instant::now(),
            #[cfg(feature = "inference")]
            inference_engine: None,
            #[cfg(feature = "inference")]
            baselines: Arc::new(DashMap::new()),
        }
    }

    /// Create application state with inference enabled.
    ///
    /// Falls back to `InferenceEngine::default()` when no config is given.
    #[cfg(feature = "inference")]
    pub fn with_inference(store: S, config: Option<InferenceConfig>) -> Self {
        let engine = match config {
            Some(c) => InferenceEngine::with_config(c),
            None => InferenceEngine::default(),
        };
        Self {
            store: Arc::new(store),
            translator: Arc::new(RuleTranslator::default()),
            start_time: Instant::now(),
            inference_engine: Some(engine),
            baselines: Arc::new(DashMap::new()),
        }
    }
}
71
/// Request body for upserting state.
#[derive(Debug, Deserialize)]
pub struct UpsertStateRequest {
    /// User ID to update state for.
    pub user_id: String,
    /// Source of the state data. Defaults to `self_report` when omitted.
    #[serde(default)]
    pub source: SourceInput,
    /// Confidence level of the state data. Defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
    /// Axis values to set. Required; may be empty if `message` drives inference.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Optional message text for inference (requires "inference" feature).
    /// When provided, axes are inferred from the message and merged with explicit axes.
    /// Explicit axes always override inferred values.
    /// Ignored when the feature is compiled out or the engine is disabled.
    #[serde(default)]
    pub message: Option<String>,
}
91
/// Serde default for `confidence` fields: full confidence.
fn default_confidence() -> f32 {
    1.0_f32
}
95
/// Source of state data in API requests.
///
/// Deserialized from snake_case strings (`"self_report"`, `"inferred"`,
/// `"mixed"`); `self_report` is the default when the field is omitted.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SourceInput {
    /// User explicitly provided this state.
    #[default]
    SelfReport,
    /// State was inferred from behavior.
    Inferred,
    /// Combination of self-report and inference.
    Mixed,
}
108
109impl From<SourceInput> for Source {
110    fn from(s: SourceInput) -> Self {
111        match s {
112            SourceInput::SelfReport => Source::SelfReport,
113            SourceInput::Inferred => Source::Inferred,
114            SourceInput::Mixed => Source::Mixed,
115        }
116    }
117}
118
/// Response for state operations.
#[derive(Debug, Serialize)]
pub struct StateResponse {
    /// User ID.
    pub user_id: String,
    /// Timestamp of last update (Unix ms).
    pub updated_at_unix_ms: i64,
    /// Source of the state data, stringified via the core enum's `Display`.
    pub source: String,
    /// Confidence level.
    pub confidence: f32,
    /// Axis values (BTreeMap keeps JSON key order deterministic).
    pub axes: std::collections::BTreeMap<String, f32>,
}
133
134impl From<StateSnapshot> for StateResponse {
135    fn from(s: StateSnapshot) -> Self {
136        Self {
137            user_id: s.user_id,
138            updated_at_unix_ms: s.updated_at_unix_ms,
139            source: s.source.to_string(),
140            confidence: s.confidence,
141            axes: s.axes,
142        }
143    }
144}
145
/// Error response format.
///
/// Serializes as a nested object, `{"error": {"code": ..., "message": ...}}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Error details.
    pub error: ErrorDetail,
}
152
/// Detailed error information.
#[derive(Debug, Serialize)]
pub struct ErrorDetail {
    /// Machine-readable error code (e.g. "VALIDATION_ERROR", "STORE_ERROR").
    pub code: String,
    /// Human-readable error message.
    pub message: String,
    /// Request ID for correlation. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
164
165impl ErrorResponse {
166    /// Create a new error response.
167    pub fn new(code: &str, message: &str) -> Self {
168        Self {
169            error: ErrorDetail {
170                code: code.to_string(),
171                message: message.to_string(),
172                request_id: None,
173            },
174        }
175    }
176}
177
/// POST /v1/state - Upsert state
///
/// Builds a `StateSnapshot` from the request body and stores it as the
/// user's latest state. When the "inference" feature is enabled and the
/// request carries a `message`, axes are inferred from the text and merged
/// in; explicitly supplied axes always win, and the source is promoted
/// from `SelfReport` to `Mixed` when inference contributed anything.
///
/// Responses: 204 on success, 400 on snapshot validation failure,
/// 500 on a store error.
#[tracing::instrument(skip(state, body))]
#[allow(unused_mut)] // mut needed when inference feature is enabled
pub async fn upsert_state<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Json(body): Json<UpsertStateRequest>,
) -> impl IntoResponse {
    let mut axes = body.axes;
    let mut source: Source = body.source.into();

    // Run inference if enabled and message provided
    #[cfg(feature = "inference")]
    if let (Some(engine), Some(message)) = (&state.inference_engine, &body.message) {
        // Get or create baseline for user.
        // NOTE(review): this holds the DashMap entry guard for the whole
        // inference call below; fine while infer_with_baseline is synchronous.
        let mut baseline_ref = state
            .baselines
            .entry(body.user_id.clone())
            .or_insert_with(|| engine.new_baseline());

        // Run inference with baseline (passed mutably, so it may be updated)
        let inferred = engine.infer_with_baseline(message, &mut baseline_ref, None);

        // Merge: explicit axes override inferred
        for estimate in inferred.all() {
            if !axes.contains_key(&estimate.axis) {
                // Only use inferred if not explicitly provided
                axes.insert(estimate.axis.clone(), estimate.value);
            }
        }

        // Mark source as mixed if we used inference
        if !inferred.is_empty() && source == Source::SelfReport {
            source = Source::Mixed;
        }
    }

    let snapshot = match StateSnapshot::builder()
        .user_id(&body.user_id)
        .source(source)
        .confidence(body.confidence)
        .axes(axes.into_iter())
        .build()
    {
        Ok(s) => s,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
            )
                .into_response();
        }
    };

    match state.store.upsert_latest(snapshot).await {
        Ok(()) => StatusCode::NO_CONTENT.into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
        )
            .into_response(),
    }
}
240
241/// GET /v1/state/:user_id - Get state
242#[tracing::instrument(skip(state))]
243pub async fn get_state<S: StateStore + 'static>(
244    State(state): State<Arc<AppState<S>>>,
245    Path(user_id): Path<String>,
246) -> impl IntoResponse {
247    match state.store.get_latest(&user_id).await {
248        Ok(Some(snapshot)) => Json(StateResponse::from(snapshot)).into_response(),
249        Ok(None) => (
250            StatusCode::NOT_FOUND,
251            Json(ErrorResponse::new(
252                "USER_NOT_FOUND",
253                &format!("No state found for user {}", user_id),
254            )),
255        )
256            .into_response(),
257        Err(e) => (
258            StatusCode::INTERNAL_SERVER_ERROR,
259            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
260        )
261            .into_response(),
262    }
263}
264
265/// DELETE /v1/state/:user_id - Delete state
266#[tracing::instrument(skip(state))]
267pub async fn delete_state<S: StateStore + 'static>(
268    State(state): State<Arc<AppState<S>>>,
269    Path(user_id): Path<String>,
270) -> impl IntoResponse {
271    match state.store.delete(&user_id).await {
272        Ok(()) => StatusCode::NO_CONTENT.into_response(),
273        Err(e) => (
274            StatusCode::INTERNAL_SERVER_ERROR,
275            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
276        )
277            .into_response(),
278    }
279}
280
281/// GET /v1/context/:user_id - Get translated context
282#[tracing::instrument(skip(state))]
283pub async fn get_context<S: StateStore + 'static>(
284    State(state): State<Arc<AppState<S>>>,
285    Path(user_id): Path<String>,
286) -> impl IntoResponse {
287    match state.store.get_latest(&user_id).await {
288        Ok(Some(snapshot)) => {
289            let context = state.translator.to_prompt_context(&snapshot);
290            Json(context).into_response()
291        }
292        Ok(None) => (
293            StatusCode::NOT_FOUND,
294            Json(ErrorResponse::new(
295                "USER_NOT_FOUND",
296                &format!("No state found for user {}", user_id),
297            )),
298        )
299            .into_response(),
300        Err(e) => (
301            StatusCode::INTERNAL_SERVER_ERROR,
302            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
303        )
304            .into_response(),
305    }
306}
307
/// Request body for inline translation.
///
/// Mirrors `UpsertStateRequest` minus the user ID and message: the same
/// axes/source/confidence trio, translated without being stored.
#[derive(Debug, Deserialize)]
pub struct TranslateRequest {
    /// Axis values to translate.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Source of the state data. Defaults to `self_report` when omitted.
    #[serde(default)]
    pub source: SourceInput,
    /// Confidence level. Defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
}
320
321/// POST /v1/translate - Translate arbitrary state
322#[tracing::instrument(skip(state, body))]
323pub async fn translate<S: StateStore + 'static>(
324    State(state): State<Arc<AppState<S>>>,
325    Json(body): Json<TranslateRequest>,
326) -> impl IntoResponse {
327    let snapshot = match StateSnapshot::builder()
328        .user_id("_anonymous")
329        .source(body.source.into())
330        .confidence(body.confidence)
331        .axes(body.axes.into_iter())
332        .build()
333    {
334        Ok(s) => s,
335        Err(e) => {
336            return (
337                StatusCode::BAD_REQUEST,
338                Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
339            )
340                .into_response();
341        }
342    };
343
344    let context = state.translator.to_prompt_context(&snapshot);
345    Json(context).into_response()
346}
347
348/// GET /health - Health check
349#[tracing::instrument(skip(state))]
350pub async fn health<S: StateStore + HealthCheck + 'static>(
351    State(state): State<Arc<AppState<S>>>,
352) -> impl IntoResponse {
353    let store_health = state.store.check().await;
354    let uptime = state.start_time.elapsed().as_secs();
355
356    let status = HealthStatus::from_checks(vec![store_health], uptime);
357
358    let status_code = match status.status {
359        HealthState::Healthy => StatusCode::OK,
360        HealthState::Degraded => StatusCode::OK,
361        HealthState::Unhealthy => StatusCode::SERVICE_UNAVAILABLE,
362    };
363
364    (status_code, Json(status))
365}
366
367/// GET /ready - Readiness check
368#[tracing::instrument(skip(state))]
369pub async fn ready<S: StateStore + 'static>(
370    State(state): State<Arc<AppState<S>>>,
371) -> impl IntoResponse {
372    match state.store.health_check().await {
373        Ok(true) => StatusCode::OK,
374        _ => StatusCode::SERVICE_UNAVAILABLE,
375    }
376}
377
/// Response for prompt context.
///
/// Serializable mirror of `PromptContext` (see the `From` impl below).
#[derive(Debug, Serialize)]
pub struct ContextResponse {
    /// Behavioral guidelines for the LLM.
    pub guidelines: Vec<String>,
    /// Suggested tone.
    pub tone: String,
    /// Desired response verbosity (lowercased Debug name of the variant).
    pub verbosity: String,
    /// Active flags for special conditions.
    pub flags: Vec<String>,
}
390
391impl From<PromptContext> for ContextResponse {
392    fn from(c: PromptContext) -> Self {
393        Self {
394            guidelines: c.guidelines,
395            tone: c.tone,
396            verbosity: format!("{:?}", c.verbosity).to_lowercase(),
397            flags: c.flags,
398        }
399    }
400}
401
402// ============================================================================
403// Inference endpoint (requires "inference" feature)
404// ============================================================================
405
/// Request body for inference endpoint.
#[cfg(feature = "inference")]
#[derive(Debug, Deserialize)]
pub struct InferRequest {
    /// The message text to analyze.
    pub message: String,
    /// Optional user ID for baseline comparison.
    /// If provided, the user's baseline will be updated.
    #[serde(default)]
    pub user_id: Option<String>,
    /// Include debug feature information in response.
    /// Off by default; when set, the response carries a `features` map
    /// of raw linguistic metrics.
    #[serde(default)]
    pub include_features: bool,
}
420
/// A single axis estimate in the inference response.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferEstimate {
    /// The axis name.
    pub axis: String,
    /// Estimated value in [0.0, 1.0].
    pub value: f32,
    /// Confidence in this estimate.
    pub confidence: f32,
    /// Source of this inference (tagged union; see `InferSourceResponse`).
    pub source: InferSourceResponse,
}
434
/// Inference source for API response.
///
/// Serialized as an internally tagged union: a `"type"` field carries the
/// snake_case variant name alongside the variant's own fields.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferSourceResponse {
    /// Inferred from linguistic features.
    Linguistic {
        /// Features that contributed to this inference.
        features_used: Vec<String>,
    },
    /// Inferred from deviation from baseline.
    Delta {
        /// Z-score deviation from baseline.
        z_score: f32,
        /// Which metric showed deviation.
        metric: String,
    },
    /// Combined from multiple sources.
    Combined {
        /// Number of sources combined.
        source_count: usize,
    },
    /// Prior/default value.
    Prior {
        /// Reason for this prior.
        reason: String,
    },
}
463
#[cfg(feature = "inference")]
impl From<&InferenceSource> for InferSourceResponse {
    /// Project the engine's source enum onto the API shape, recursing
    /// through `Decayed` wrappers to surface the underlying source.
    fn from(src: &InferenceSource) -> Self {
        match src {
            InferenceSource::Linguistic { features_used, .. } => Self::Linguistic {
                features_used: features_used.clone(),
            },
            InferenceSource::Delta {
                z_score, metric, ..
            } => Self::Delta {
                z_score: *z_score,
                metric: metric.clone(),
            },
            InferenceSource::Combined { sources, .. } => Self::Combined {
                source_count: sources.len(),
            },
            InferenceSource::Prior { reason } => Self::Prior {
                reason: reason.clone(),
            },
            // Peel the decay wrapper and report the original source.
            InferenceSource::Decayed { original, .. } => Self::from(original.as_ref()),
            // Not expected from inference output; degrade to a prior.
            InferenceSource::SelfReport => Self::Prior {
                reason: "self_report".into(),
            },
        }
    }
}
496
/// Response for inference endpoint.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferResponse {
    /// Estimated axes.
    pub estimates: Vec<InferEstimate>,
    /// Debug feature information (if requested via `include_features`).
    /// Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub features: Option<HashMap<String, serde_json::Value>>,
}
507
/// POST /v1/infer - Infer axes from message text without storage
///
/// Runs the inference engine over `message` — against (and updating) the
/// user's baseline when `user_id` is supplied, statelessly otherwise —
/// and returns the per-axis estimates. With `include_features` set, the
/// raw linguistic metrics are re-extracted and attached for debugging.
/// 503 when the engine is not configured on this server.
#[cfg(feature = "inference")]
#[tracing::instrument(skip(state, body))]
pub async fn infer<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Json(body): Json<InferRequest>,
) -> impl IntoResponse {
    let Some(engine) = &state.inference_engine else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse::new(
                "INFERENCE_DISABLED",
                "Inference is not enabled on this server",
            )),
        )
            .into_response();
    };

    // Baseline-aware inference when a user ID was supplied, else stateless.
    let inferred = match &body.user_id {
        Some(uid) => {
            let mut baseline = state
                .baselines
                .entry(uid.clone())
                .or_insert_with(|| engine.new_baseline());
            engine.infer_with_baseline(&body.message, &mut baseline, None)
        }
        None => engine.infer(&body.message),
    };

    // Project engine estimates into the wire format.
    let estimates = inferred
        .all()
        .map(|est| InferEstimate {
            axis: est.axis.clone(),
            value: est.value,
            confidence: est.confidence,
            source: InferSourceResponse::from(&est.source),
        })
        .collect::<Vec<InferEstimate>>();

    // Optionally attach the raw linguistic metrics for debugging.
    let features = body.include_features.then(|| {
        let f = attuned_infer::LinguisticExtractor::new().extract(&body.message);
        let pairs: [(&str, serde_json::Value); 9] = [
            ("word_count", serde_json::json!(f.word_count)),
            ("sentence_count", serde_json::json!(f.sentence_count)),
            ("hedge_count", serde_json::json!(f.hedge_count)),
            ("urgency_word_count", serde_json::json!(f.urgency_word_count)),
            ("negative_emotion_count", serde_json::json!(f.negative_emotion_count)),
            ("exclamation_ratio", serde_json::json!(f.exclamation_ratio)),
            ("question_ratio", serde_json::json!(f.question_ratio)),
            ("caps_ratio", serde_json::json!(f.caps_ratio)),
            ("first_person_ratio", serde_json::json!(f.first_person_ratio)),
        ];
        pairs
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect::<HashMap<String, serde_json::Value>>()
    });

    Json(InferResponse {
        estimates,
        features,
    })
    .into_response()
}