1use attuned_core::{
4 HealthCheck, HealthState, HealthStatus, PromptContext, RuleTranslator, Source, StateSnapshot,
5 Translator,
6};
7use attuned_store::StateStore;
8use axum::{
9 extract::{Path, State},
10 http::StatusCode,
11 response::IntoResponse,
12 Json,
13};
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use std::time::Instant;
17
18#[cfg(feature = "inference")]
19use attuned_infer::{Baseline, InferenceConfig, InferenceEngine, InferenceSource};
20#[cfg(feature = "inference")]
21use dashmap::DashMap;
22#[cfg(feature = "inference")]
23use std::collections::HashMap;
24
/// Shared application state injected into every axum handler.
///
/// Generic over the storage backend `S` so the same handlers work with
/// any [`StateStore`] implementation.
pub struct AppState<S: StateStore> {
    /// Persistent snapshot store shared across handlers.
    pub store: Arc<S>,
    /// Translates a state snapshot into a prompt context (rule-based by default).
    pub translator: Arc<dyn Translator>,
    /// Server start instant; used to report uptime in the health endpoint.
    pub start_time: Instant,
    /// Optional inference engine; `None` means inference is disabled at runtime
    /// even though the feature is compiled in.
    #[cfg(feature = "inference")]
    pub inference_engine: Option<InferenceEngine>,
    /// Per-user rolling baselines keyed by user id, shared across requests.
    #[cfg(feature = "inference")]
    pub baselines: Arc<DashMap<String, Baseline>>,
}
40
41impl<S: StateStore> AppState<S> {
42 pub fn new(store: S) -> Self {
44 Self {
45 store: Arc::new(store),
46 translator: Arc::new(RuleTranslator::default()),
47 start_time: Instant::now(),
48 #[cfg(feature = "inference")]
49 inference_engine: None,
50 #[cfg(feature = "inference")]
51 baselines: Arc::new(DashMap::new()),
52 }
53 }
54
55 #[cfg(feature = "inference")]
57 pub fn with_inference(store: S, config: Option<InferenceConfig>) -> Self {
58 let engine = match config {
59 Some(c) => InferenceEngine::with_config(c),
60 None => InferenceEngine::default(),
61 };
62 Self {
63 store: Arc::new(store),
64 translator: Arc::new(RuleTranslator::default()),
65 start_time: Instant::now(),
66 inference_engine: Some(engine),
67 baselines: Arc::new(DashMap::new()),
68 }
69 }
70}
71
/// Request body for upserting a user's latest state snapshot.
#[derive(Debug, Deserialize)]
pub struct UpsertStateRequest {
    /// Identifier of the user whose state is being written.
    pub user_id: String,
    /// Provenance of the axes; defaults to self-report when omitted.
    #[serde(default)]
    pub source: SourceInput,
    /// Confidence in the reported axes; defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
    /// Axis name -> value map supplied by the client.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Optional raw message; when inference is enabled it is analyzed to
    /// fill in axes the client did not provide explicitly.
    #[serde(default)]
    pub message: Option<String>,
}
91
/// Serde default for `confidence`: an omitted value means full confidence.
fn default_confidence() -> f32 {
    1.0_f32
}
95
/// Wire-format provenance for incoming axes; deserialized from snake_case
/// (e.g. `"self_report"`).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SourceInput {
    /// Values reported directly by the user (the default when omitted).
    #[default]
    SelfReport,
    /// Values produced by inference.
    Inferred,
    /// A mix of self-reported and inferred values.
    Mixed,
}
108
/// One-to-one mapping from the wire-format enum onto the core [`Source`] type.
impl From<SourceInput> for Source {
    fn from(s: SourceInput) -> Self {
        match s {
            SourceInput::SelfReport => Source::SelfReport,
            SourceInput::Inferred => Source::Inferred,
            SourceInput::Mixed => Source::Mixed,
        }
    }
}
118
/// Response body describing a stored state snapshot.
#[derive(Debug, Serialize)]
pub struct StateResponse {
    /// User the snapshot belongs to.
    pub user_id: String,
    /// Last update time in Unix milliseconds.
    pub updated_at_unix_ms: i64,
    /// Provenance of the snapshot, rendered via `Source`'s `Display` impl.
    pub source: String,
    /// Confidence associated with the snapshot.
    pub confidence: f32,
    /// Axis name -> value map.
    pub axes: std::collections::BTreeMap<String, f32>,
}
133
/// Field-for-field projection of a core [`StateSnapshot`] into the API
/// response shape; `source` is stringified via its `Display` impl.
impl From<StateSnapshot> for StateResponse {
    fn from(s: StateSnapshot) -> Self {
        Self {
            user_id: s.user_id,
            updated_at_unix_ms: s.updated_at_unix_ms,
            source: s.source.to_string(),
            confidence: s.confidence,
            axes: s.axes,
        }
    }
}
145
/// Top-level JSON error envelope returned by all handlers.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// The error payload (code, message, optional request id).
    pub error: ErrorDetail,
}
152
/// Machine- and human-readable error details.
#[derive(Debug, Serialize)]
pub struct ErrorDetail {
    /// Stable machine-readable error code (e.g. "STORE_ERROR").
    pub code: String,
    /// Human-readable description of the failure.
    pub message: String,
    /// Correlation id; omitted from the JSON output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
164
165impl ErrorResponse {
166 pub fn new(code: &str, message: &str) -> Self {
168 Self {
169 error: ErrorDetail {
170 code: code.to_string(),
171 message: message.to_string(),
172 request_id: None,
173 },
174 }
175 }
176}
177
178#[tracing::instrument(skip(state, body))]
180#[allow(unused_mut)] pub async fn upsert_state<S: StateStore + 'static>(
182 State(state): State<Arc<AppState<S>>>,
183 Json(body): Json<UpsertStateRequest>,
184) -> impl IntoResponse {
185 let mut axes = body.axes;
186 let mut source: Source = body.source.into();
187
188 #[cfg(feature = "inference")]
190 if let (Some(engine), Some(message)) = (&state.inference_engine, &body.message) {
191 let mut baseline_ref = state
193 .baselines
194 .entry(body.user_id.clone())
195 .or_insert_with(|| engine.new_baseline());
196
197 let inferred = engine.infer_with_baseline(message, &mut baseline_ref, None);
199
200 for estimate in inferred.all() {
202 if !axes.contains_key(&estimate.axis) {
203 axes.insert(estimate.axis.clone(), estimate.value);
205 }
206 }
207
208 if !inferred.is_empty() && source == Source::SelfReport {
210 source = Source::Mixed;
211 }
212 }
213
214 let snapshot = match StateSnapshot::builder()
215 .user_id(&body.user_id)
216 .source(source)
217 .confidence(body.confidence)
218 .axes(axes.into_iter())
219 .build()
220 {
221 Ok(s) => s,
222 Err(e) => {
223 return (
224 StatusCode::BAD_REQUEST,
225 Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
226 )
227 .into_response();
228 }
229 };
230
231 match state.store.upsert_latest(snapshot).await {
232 Ok(()) => StatusCode::NO_CONTENT.into_response(),
233 Err(e) => (
234 StatusCode::INTERNAL_SERVER_ERROR,
235 Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
236 )
237 .into_response(),
238 }
239}
240
241#[tracing::instrument(skip(state))]
243pub async fn get_state<S: StateStore + 'static>(
244 State(state): State<Arc<AppState<S>>>,
245 Path(user_id): Path<String>,
246) -> impl IntoResponse {
247 match state.store.get_latest(&user_id).await {
248 Ok(Some(snapshot)) => Json(StateResponse::from(snapshot)).into_response(),
249 Ok(None) => (
250 StatusCode::NOT_FOUND,
251 Json(ErrorResponse::new(
252 "USER_NOT_FOUND",
253 &format!("No state found for user {}", user_id),
254 )),
255 )
256 .into_response(),
257 Err(e) => (
258 StatusCode::INTERNAL_SERVER_ERROR,
259 Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
260 )
261 .into_response(),
262 }
263}
264
265#[tracing::instrument(skip(state))]
267pub async fn delete_state<S: StateStore + 'static>(
268 State(state): State<Arc<AppState<S>>>,
269 Path(user_id): Path<String>,
270) -> impl IntoResponse {
271 match state.store.delete(&user_id).await {
272 Ok(()) => StatusCode::NO_CONTENT.into_response(),
273 Err(e) => (
274 StatusCode::INTERNAL_SERVER_ERROR,
275 Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
276 )
277 .into_response(),
278 }
279}
280
281#[tracing::instrument(skip(state))]
283pub async fn get_context<S: StateStore + 'static>(
284 State(state): State<Arc<AppState<S>>>,
285 Path(user_id): Path<String>,
286) -> impl IntoResponse {
287 match state.store.get_latest(&user_id).await {
288 Ok(Some(snapshot)) => {
289 let context = state.translator.to_prompt_context(&snapshot);
290 Json(context).into_response()
291 }
292 Ok(None) => (
293 StatusCode::NOT_FOUND,
294 Json(ErrorResponse::new(
295 "USER_NOT_FOUND",
296 &format!("No state found for user {}", user_id),
297 )),
298 )
299 .into_response(),
300 Err(e) => (
301 StatusCode::INTERNAL_SERVER_ERROR,
302 Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
303 )
304 .into_response(),
305 }
306}
307
/// Request body for stateless translation of axes into a prompt context.
#[derive(Debug, Deserialize)]
pub struct TranslateRequest {
    /// Axis name -> value map to translate.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Provenance of the axes; defaults to self-report when omitted.
    #[serde(default)]
    pub source: SourceInput,
    /// Confidence in the axes; defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
}
320
321#[tracing::instrument(skip(state, body))]
323pub async fn translate<S: StateStore + 'static>(
324 State(state): State<Arc<AppState<S>>>,
325 Json(body): Json<TranslateRequest>,
326) -> impl IntoResponse {
327 let snapshot = match StateSnapshot::builder()
328 .user_id("_anonymous")
329 .source(body.source.into())
330 .confidence(body.confidence)
331 .axes(body.axes.into_iter())
332 .build()
333 {
334 Ok(s) => s,
335 Err(e) => {
336 return (
337 StatusCode::BAD_REQUEST,
338 Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
339 )
340 .into_response();
341 }
342 };
343
344 let context = state.translator.to_prompt_context(&snapshot);
345 Json(context).into_response()
346}
347
348#[tracing::instrument(skip(state))]
350pub async fn health<S: StateStore + HealthCheck + 'static>(
351 State(state): State<Arc<AppState<S>>>,
352) -> impl IntoResponse {
353 let store_health = state.store.check().await;
354 let uptime = state.start_time.elapsed().as_secs();
355
356 let status = HealthStatus::from_checks(vec![store_health], uptime);
357
358 let status_code = match status.status {
359 HealthState::Healthy => StatusCode::OK,
360 HealthState::Degraded => StatusCode::OK,
361 HealthState::Unhealthy => StatusCode::SERVICE_UNAVAILABLE,
362 };
363
364 (status_code, Json(status))
365}
366
367#[tracing::instrument(skip(state))]
369pub async fn ready<S: StateStore + 'static>(
370 State(state): State<Arc<AppState<S>>>,
371) -> impl IntoResponse {
372 match state.store.health_check().await {
373 Ok(true) => StatusCode::OK,
374 _ => StatusCode::SERVICE_UNAVAILABLE,
375 }
376}
377
/// Serializable view of a prompt context.
///
/// NOTE(review): the handlers in this file serialize the translator's
/// `PromptContext` directly, so this type appears unused here — confirm
/// whether external callers rely on it before removing.
#[derive(Debug, Serialize)]
pub struct ContextResponse {
    /// Guideline sentences to include in the prompt.
    pub guidelines: Vec<String>,
    /// Suggested tone descriptor.
    pub tone: String,
    /// Verbosity level as a lowercase string (from the enum's `Debug` form).
    pub verbosity: String,
    /// Additional flags attached to the context.
    pub flags: Vec<String>,
}
390
/// Converts a core [`PromptContext`] into its serializable API view.
impl From<PromptContext> for ContextResponse {
    fn from(c: PromptContext) -> Self {
        Self {
            guidelines: c.guidelines,
            tone: c.tone,
            // Lowercased `Debug` rendering of the verbosity enum; assumes
            // variant names are single words with no `Display` impl to
            // prefer — NOTE(review): confirm against the enum definition.
            verbosity: format!("{:?}", c.verbosity).to_lowercase(),
            flags: c.flags,
        }
    }
}
401
/// Request body for the standalone inference endpoint.
#[cfg(feature = "inference")]
#[derive(Debug, Deserialize)]
pub struct InferRequest {
    /// Message text to analyze.
    pub message: String,
    /// Optional user id; when present, inference uses and updates that
    /// user's rolling baseline.
    #[serde(default)]
    pub user_id: Option<String>,
    /// When true, the raw linguistic features are echoed back in the response.
    #[serde(default)]
    pub include_features: bool,
}
420
/// A single per-axis estimate produced by the inference engine.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferEstimate {
    /// Axis name the estimate applies to.
    pub axis: String,
    /// Estimated value for the axis.
    pub value: f32,
    /// Confidence in the estimate.
    pub confidence: f32,
    /// Provenance of the estimate.
    pub source: InferSourceResponse,
}
434
/// Wire-format provenance for an inference estimate, serialized with an
/// internal `type` tag in snake_case.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferSourceResponse {
    /// Derived from linguistic features of the message.
    Linguistic {
        /// Names of the features that contributed.
        features_used: Vec<String>,
    },
    /// Derived from a deviation against the user's baseline.
    Delta {
        /// Standardized deviation from the baseline.
        z_score: f32,
        /// Baseline metric that deviated.
        metric: String,
    },
    /// Combination of several underlying sources.
    Combined {
        /// Number of sources that were combined.
        source_count: usize,
    },
    /// Fallback/prior estimate with an explanatory reason
    /// (self-reported values also map here with reason "self_report").
    Prior {
        /// Why the prior was used.
        reason: String,
    },
}
463
#[cfg(feature = "inference")]
impl From<&InferenceSource> for InferSourceResponse {
    /// Projects the engine's internal provenance onto the wire enum,
    /// dropping any fields the API does not expose.
    fn from(source: &InferenceSource) -> Self {
        match source {
            InferenceSource::Linguistic { features_used, .. } => Self::Linguistic {
                features_used: features_used.clone(),
            },
            InferenceSource::Delta { z_score, metric, .. } => Self::Delta {
                z_score: *z_score,
                metric: metric.clone(),
            },
            InferenceSource::Combined { sources, .. } => Self::Combined {
                source_count: sources.len(),
            },
            InferenceSource::Prior { reason } => Self::Prior {
                reason: reason.clone(),
            },
            // A decayed estimate is reported as whatever it decayed from.
            InferenceSource::Decayed { original, .. } => Self::from(original.as_ref()),
            // Self-reported values surface as a prior with a fixed reason tag.
            InferenceSource::SelfReport => Self::Prior {
                reason: "self_report".into(),
            },
        }
    }
}
496
/// Response body of the inference endpoint.
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferResponse {
    /// Per-axis estimates produced by the engine.
    pub estimates: Vec<InferEstimate>,
    /// Raw linguistic features; present only when requested, and omitted
    /// from the JSON output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub features: Option<HashMap<String, serde_json::Value>>,
}
507
/// Inference handler: runs the engine over `body.message`.
///
/// With a `user_id`, inference uses and updates that user's rolling
/// baseline; without one it is stateless. Answers `503` with
/// `INFERENCE_DISABLED` when the server was started without an engine.
#[cfg(feature = "inference")]
#[tracing::instrument(skip(state, body))]
pub async fn infer<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Json(body): Json<InferRequest>,
) -> impl IntoResponse {
    let Some(engine) = &state.inference_engine else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse::new(
                "INFERENCE_DISABLED",
                "Inference is not enabled on this server",
            )),
        )
            .into_response();
    };

    let inferred = if let Some(user_id) = &body.user_id {
        // Per-user baseline, created on first sight; the DashMap entry
        // guard drops at the end of this arm, before any `.await`.
        let mut baseline_ref = state
            .baselines
            .entry(user_id.clone())
            .or_insert_with(|| engine.new_baseline());
        engine.infer_with_baseline(&body.message, &mut baseline_ref, None)
    } else {
        engine.infer(&body.message)
    };

    let estimates: Vec<InferEstimate> = inferred
        .all()
        .map(|est| InferEstimate {
            axis: est.axis.clone(),
            value: est.value,
            confidence: est.confidence,
            source: InferSourceResponse::from(&est.source),
        })
        .collect();

    // Raw feature echo is opt-in debug output; only extract when asked.
    let features = body
        .include_features
        .then(|| extract_linguistic_features(&body.message));

    Json(InferResponse {
        estimates,
        features,
    })
    .into_response()
}

/// Re-extracts the raw linguistic feature counts/ratios for a message so
/// callers can inspect the signals available to the inference engine.
#[cfg(feature = "inference")]
fn extract_linguistic_features(message: &str) -> HashMap<String, serde_json::Value> {
    let f = attuned_infer::LinguisticExtractor::new().extract(message);
    HashMap::from([
        ("word_count".to_string(), serde_json::json!(f.word_count)),
        (
            "sentence_count".to_string(),
            serde_json::json!(f.sentence_count),
        ),
        ("hedge_count".to_string(), serde_json::json!(f.hedge_count)),
        (
            "urgency_word_count".to_string(),
            serde_json::json!(f.urgency_word_count),
        ),
        (
            "negative_emotion_count".to_string(),
            serde_json::json!(f.negative_emotion_count),
        ),
        (
            "exclamation_ratio".to_string(),
            serde_json::json!(f.exclamation_ratio),
        ),
        (
            "question_ratio".to_string(),
            serde_json::json!(f.question_ratio),
        ),
        ("caps_ratio".to_string(), serde_json::json!(f.caps_ratio)),
        (
            "first_person_ratio".to_string(),
            serde_json::json!(f.first_person_ratio),
        ),
    ])
}