use attuned_core::{
HealthCheck, HealthState, HealthStatus, PromptContext, RuleTranslator, Source, StateSnapshot,
Translator,
};
use attuned_store::StateStore;
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
#[cfg(feature = "inference")]
use attuned_infer::{Baseline, InferenceConfig, InferenceEngine, InferenceSource};
#[cfg(feature = "inference")]
use dashmap::DashMap;
#[cfg(feature = "inference")]
use std::collections::HashMap;
/// Shared application state handed to every axum handler via `State<Arc<AppState<S>>>`.
pub struct AppState<S: StateStore> {
    /// Persistence backend for per-user state snapshots.
    pub store: Arc<S>,
    /// Converts snapshots into prompt guidance; `RuleTranslator` by default (see `new`).
    pub translator: Arc<dyn Translator>,
    /// Construction time of this state; used to report uptime from the health endpoint.
    pub start_time: Instant,
    /// Optional engine for inferring axes from free-text messages.
    /// `None` when built via `new`, even with the feature compiled in.
    #[cfg(feature = "inference")]
    pub inference_engine: Option<InferenceEngine>,
    /// Per-user inference baselines, keyed by user id, shared across requests.
    #[cfg(feature = "inference")]
    pub baselines: Arc<DashMap<String, Baseline>>,
}
impl<S: StateStore> AppState<S> {
    /// Build application state around `store` with the default rule-based
    /// translator and no inference engine attached.
    pub fn new(store: S) -> Self {
        Self {
            store: Arc::new(store),
            translator: Arc::new(RuleTranslator::default()),
            start_time: Instant::now(),
            #[cfg(feature = "inference")]
            inference_engine: None,
            #[cfg(feature = "inference")]
            baselines: Arc::new(DashMap::new()),
        }
    }

    /// Same as [`AppState::new`], but with an inference engine attached.
    ///
    /// A `Some(config)` builds the engine from that config; `None` uses the
    /// engine's defaults.
    #[cfg(feature = "inference")]
    pub fn with_inference(store: S, config: Option<InferenceConfig>) -> Self {
        let mut app = Self::new(store);
        app.inference_engine =
            Some(config.map_or_else(InferenceEngine::default, InferenceEngine::with_config));
        app
    }
}
/// Request body for `upsert_state`.
#[derive(Debug, Deserialize)]
pub struct UpsertStateRequest {
    /// User whose latest snapshot is being replaced.
    pub user_id: String,
    /// Provenance of the supplied axes; defaults to `self_report`.
    #[serde(default)]
    pub source: SourceInput,
    /// Reporter confidence in the axes; defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
    /// Axis name -> value map supplied by the client.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Optional free-text message; fed to the inference engine (when enabled)
    /// to fill in axes the client did not supply.
    #[serde(default)]
    pub message: Option<String>,
}
/// Serde default for `confidence` when the client omits the field.
fn default_confidence() -> f32 {
    1.0
}
/// Wire-format provenance tag, deserialized from snake_case strings
/// (`"self_report"`, `"inferred"`, `"mixed"`).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SourceInput {
    /// Default when the request omits `source`.
    #[default]
    SelfReport,
    Inferred,
    Mixed,
}
impl From<SourceInput> for Source {
    /// One-to-one mapping from the wire enum onto the core `Source` type.
    fn from(s: SourceInput) -> Self {
        match s {
            SourceInput::SelfReport => Source::SelfReport,
            SourceInput::Inferred => Source::Inferred,
            SourceInput::Mixed => Source::Mixed,
        }
    }
}
/// JSON body returned by `get_state`: a snapshot flattened for the wire.
#[derive(Debug, Serialize)]
pub struct StateResponse {
    pub user_id: String,
    /// Last update time, Unix epoch milliseconds.
    pub updated_at_unix_ms: i64,
    /// Provenance rendered via the core `Source`'s `Display` impl.
    pub source: String,
    pub confidence: f32,
    pub axes: std::collections::BTreeMap<String, f32>,
}
impl From<StateSnapshot> for StateResponse {
fn from(s: StateSnapshot) -> Self {
Self {
user_id: s.user_id,
updated_at_unix_ms: s.updated_at_unix_ms,
source: s.source.to_string(),
confidence: s.confidence,
axes: s.axes,
}
}
}
/// Standard error envelope: `{"error": {"code": ..., "message": ...}}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}
/// Machine-readable error payload nested inside [`ErrorResponse`].
#[derive(Debug, Serialize)]
pub struct ErrorDetail {
    /// Stable machine code, e.g. `VALIDATION_ERROR`, `STORE_ERROR`.
    pub code: String,
    /// Human-readable explanation.
    pub message: String,
    /// Correlation id; omitted from JSON when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}
impl ErrorResponse {
pub fn new(code: &str, message: &str) -> Self {
Self {
error: ErrorDetail {
code: code.to_string(),
message: message.to_string(),
request_id: None,
},
}
}
}
/// Upsert a user's latest state snapshot.
///
/// Client-supplied axes always win. When the `inference` feature is compiled
/// in, an engine is configured, and the request carries a `message`, inferred
/// axes fill in any the client did not set, and a pure self-report becomes
/// `Mixed` if inference contributed anything.
///
/// Responses: `204` on success, `400` on snapshot validation failure,
/// `500` on a store error.
#[tracing::instrument(skip(state, body))]
// `mut` on `axes`/`source` is only exercised when the `inference` feature is enabled.
#[allow(unused_mut)] pub async fn upsert_state<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Json(body): Json<UpsertStateRequest>,
) -> impl IntoResponse {
    let mut axes = body.axes;
    let mut source: Source = body.source.into();
    #[cfg(feature = "inference")]
    if let (Some(engine), Some(message)) = (&state.inference_engine, &body.message) {
        // Fetch (or lazily create) this user's baseline; the DashMap entry
        // guard stays held for the duration of the inference call.
        let mut baseline_ref = state
            .baselines
            .entry(body.user_id.clone())
            .or_insert_with(|| engine.new_baseline());
        let inferred = engine.infer_with_baseline(message, &mut baseline_ref, None);
        // Only fill axes the client left unset — explicit values take precedence.
        for estimate in inferred.all() {
            if !axes.contains_key(&estimate.axis) {
                axes.insert(estimate.axis.clone(), estimate.value);
            }
        }
        // Any inferred contribution means the snapshot is no longer pure self-report.
        if !inferred.is_empty() && source == Source::SelfReport {
            source = Source::Mixed;
        }
    }
    let snapshot = match StateSnapshot::builder()
        .user_id(&body.user_id)
        .source(source)
        .confidence(body.confidence)
        .axes(axes.into_iter())
        .build()
    {
        Ok(s) => s,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
            )
                .into_response();
        }
    };
    match state.store.upsert_latest(snapshot).await {
        Ok(()) => StatusCode::NO_CONTENT.into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
        )
            .into_response(),
    }
}
/// Fetch the latest stored snapshot for `user_id` as JSON.
///
/// Responses: `200` with the snapshot, `404` when the user has no state,
/// `500` on a store error.
#[tracing::instrument(skip(state))]
pub async fn get_state<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Path(user_id): Path<String>,
) -> impl IntoResponse {
    let lookup = state.store.get_latest(&user_id).await;
    match lookup {
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
        )
            .into_response(),
        Ok(None) => (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(
                "USER_NOT_FOUND",
                &format!("No state found for user {}", user_id),
            )),
        )
            .into_response(),
        Ok(Some(snapshot)) => Json(StateResponse::from(snapshot)).into_response(),
    }
}
/// Delete the stored state for `user_id`.
///
/// Returns `204` whenever the store delete succeeds, `500` otherwise.
#[tracing::instrument(skip(state))]
pub async fn delete_state<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Path(user_id): Path<String>,
) -> impl IntoResponse {
    if let Err(e) = state.store.delete(&user_id).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
        )
            .into_response();
    }
    StatusCode::NO_CONTENT.into_response()
}
/// Translate the user's latest snapshot into prompt context and return it.
///
/// Responses: `200` with the translated context, `404` when the user has no
/// state, `500` on a store error.
#[tracing::instrument(skip(state))]
pub async fn get_context<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Path(user_id): Path<String>,
) -> impl IntoResponse {
    let lookup = state.store.get_latest(&user_id).await;
    match lookup {
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new("STORE_ERROR", &e.to_string())),
        )
            .into_response(),
        Ok(None) => (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(
                "USER_NOT_FOUND",
                &format!("No state found for user {}", user_id),
            )),
        )
            .into_response(),
        Ok(Some(snapshot)) => {
            Json(state.translator.to_prompt_context(&snapshot)).into_response()
        }
    }
}
/// Request body for the stateless `translate` endpoint.
#[derive(Debug, Deserialize)]
pub struct TranslateRequest {
    /// Axis name -> value map to translate.
    pub axes: std::collections::BTreeMap<String, f32>,
    /// Provenance tag; defaults to `self_report`.
    #[serde(default)]
    pub source: SourceInput,
    /// Confidence in the axes; defaults to 1.0 when omitted.
    #[serde(default = "default_confidence")]
    pub confidence: f32,
}
/// Translate an ad-hoc set of axes into prompt context without touching the
/// store.
///
/// Responses: `200` with the translated context, `400` when snapshot
/// validation fails.
#[tracing::instrument(skip(state, body))]
pub async fn translate<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
    Json(body): Json<TranslateRequest>,
) -> impl IntoResponse {
    // Stateless translation: build a throwaway snapshot under a placeholder id.
    let built = StateSnapshot::builder()
        .user_id("_anonymous")
        .source(body.source.into())
        .confidence(body.confidence)
        .axes(body.axes.into_iter())
        .build();
    match built {
        Err(e) => (
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse::new("VALIDATION_ERROR", &e.to_string())),
        )
            .into_response(),
        Ok(snapshot) => Json(state.translator.to_prompt_context(&snapshot)).into_response(),
    }
}
/// Aggregate health endpoint: combines the store's health check with uptime.
///
/// Healthy and Degraded both report `200` (degraded still serves traffic);
/// only Unhealthy yields `503`.
#[tracing::instrument(skip(state))]
pub async fn health<S: StateStore + HealthCheck + 'static>(
    State(state): State<Arc<AppState<S>>>,
) -> impl IntoResponse {
    let store_health = state.store.check().await;
    let uptime = state.start_time.elapsed().as_secs();
    let status = HealthStatus::from_checks(vec![store_health], uptime);
    let status_code = if matches!(status.status, HealthState::Unhealthy) {
        StatusCode::SERVICE_UNAVAILABLE
    } else {
        StatusCode::OK
    };
    (status_code, Json(status))
}
/// Readiness probe: `200` only when the store affirmatively reports healthy,
/// `503` for both an error and an explicit "not healthy" answer.
#[tracing::instrument(skip(state))]
pub async fn ready<S: StateStore + 'static>(
    State(state): State<Arc<AppState<S>>>,
) -> impl IntoResponse {
    if matches!(state.store.health_check().await, Ok(true)) {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    }
}
/// Wire shape for translated prompt context.
///
/// NOTE(review): the handlers in this file serialize `PromptContext` directly
/// (`get_context`, `translate`); confirm this wrapper is actually used by
/// callers outside this file.
#[derive(Debug, Serialize)]
pub struct ContextResponse {
    pub guidelines: Vec<String>,
    pub tone: String,
    /// Verbosity enum rendered as a lowercase string (via its `Debug` name).
    pub verbosity: String,
    pub flags: Vec<String>,
}
impl From<PromptContext> for ContextResponse {
    /// Convert core prompt context into the API wire shape; verbosity is
    /// rendered through its `Debug` name, lowercased.
    fn from(c: PromptContext) -> Self {
        let verbosity = format!("{:?}", c.verbosity).to_lowercase();
        Self {
            guidelines: c.guidelines,
            tone: c.tone,
            verbosity,
            flags: c.flags,
        }
    }
}
#[cfg(feature = "inference")]
#[derive(Debug, Deserialize)]
pub struct InferRequest {
pub message: String,
#[serde(default)]
pub user_id: Option<String>,
#[serde(default)]
pub include_features: bool,
}
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferEstimate {
pub axis: String,
pub value: f32,
pub confidence: f32,
pub source: InferSourceResponse,
}
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferSourceResponse {
Linguistic {
features_used: Vec<String>,
},
Delta {
z_score: f32,
metric: String,
},
Combined {
source_count: usize,
},
Prior {
reason: String,
},
}
#[cfg(feature = "inference")]
impl From<&InferenceSource> for InferSourceResponse {
fn from(source: &InferenceSource) -> Self {
match source {
InferenceSource::Linguistic { features_used, .. } => InferSourceResponse::Linguistic {
features_used: features_used.clone(),
},
InferenceSource::Delta {
z_score, metric, ..
} => InferSourceResponse::Delta {
z_score: *z_score,
metric: metric.clone(),
},
InferenceSource::Combined { sources, .. } => InferSourceResponse::Combined {
source_count: sources.len(),
},
InferenceSource::Prior { reason } => InferSourceResponse::Prior {
reason: reason.clone(),
},
InferenceSource::Decayed { original, .. } => {
InferSourceResponse::from(original.as_ref())
}
InferenceSource::SelfReport => {
InferSourceResponse::Prior {
reason: "self_report".into(),
}
}
}
}
}
#[cfg(feature = "inference")]
#[derive(Debug, Serialize)]
pub struct InferResponse {
pub estimates: Vec<InferEstimate>,
#[serde(skip_serializing_if = "Option::is_none")]
pub features: Option<HashMap<String, serde_json::Value>>,
}
#[cfg(feature = "inference")]
#[tracing::instrument(skip(state, body))]
pub async fn infer<S: StateStore + 'static>(
State(state): State<Arc<AppState<S>>>,
Json(body): Json<InferRequest>,
) -> impl IntoResponse {
let Some(engine) = &state.inference_engine else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse::new(
"INFERENCE_DISABLED",
"Inference is not enabled on this server",
)),
)
.into_response();
};
let inferred = if let Some(user_id) = &body.user_id {
let mut baseline_ref = state
.baselines
.entry(user_id.clone())
.or_insert_with(|| engine.new_baseline());
engine.infer_with_baseline(&body.message, &mut baseline_ref, None)
} else {
engine.infer(&body.message)
};
let estimates: Vec<InferEstimate> = inferred
.all()
.map(|est| InferEstimate {
axis: est.axis.clone(),
value: est.value,
confidence: est.confidence,
source: InferSourceResponse::from(&est.source),
})
.collect();
let features = if body.include_features {
let extractor = attuned_infer::LinguisticExtractor::new();
let f = extractor.extract(&body.message);
let mut map = HashMap::new();
map.insert("word_count".into(), serde_json::json!(f.word_count));
map.insert("sentence_count".into(), serde_json::json!(f.sentence_count));
map.insert("hedge_count".into(), serde_json::json!(f.hedge_count));
map.insert(
"urgency_word_count".into(),
serde_json::json!(f.urgency_word_count),
);
map.insert(
"negative_emotion_count".into(),
serde_json::json!(f.negative_emotion_count),
);
map.insert(
"exclamation_ratio".into(),
serde_json::json!(f.exclamation_ratio),
);
map.insert("question_ratio".into(), serde_json::json!(f.question_ratio));
map.insert("caps_ratio".into(), serde_json::json!(f.caps_ratio));
map.insert(
"first_person_ratio".into(),
serde_json::json!(f.first_person_ratio),
);
Some(map)
} else {
None
};
Json(InferResponse {
estimates,
features,
})
.into_response()
}