use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Notify;
pub use tokio_util::sync::CancellationToken;
use crate::bridge::protocol::{LogSource, MetricMode};
use crate::webhook::{WebhookEventType, WebhookSender};
/// Capacity of the broadcast channel that carries live stream events for each prediction.
pub const STREAM_CHANNEL_CAPACITY: usize = 1024;
/// Number of stream events retained for replay when `COG_STREAM_HISTORY_CAPACITY` is unset or
/// cannot be used.
pub const DEFAULT_STREAM_HISTORY_CAPACITY: usize = 1024;
/// Name of the environment variable that overrides the per-prediction replay history capacity.
/// A value of `0` disables history retention.
const STREAM_HISTORY_CAPACITY_ENV: &str = "COG_STREAM_HISTORY_CAPACITY";
/// Lifecycle state of a prediction.
///
/// `Succeeded`, `Failed` and `Canceled` are terminal; `Starting` and `Processing` are not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionStatus {
    /// Created but not yet marked as processing.
    Starting,
    /// Marked as processing.
    Processing,
    /// Finished with an output.
    Succeeded,
    /// Finished with an error.
    Failed,
    /// Finished because it was canceled.
    Canceled,
}

impl PredictionStatus {
    /// Returns `true` when the status is one of the final states
    /// (`Succeeded`, `Failed`, `Canceled`).
    pub fn is_terminal(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match self {
            Self::Starting | Self::Processing => false,
            Self::Succeeded | Self::Failed | Self::Canceled => true,
        }
    }

    /// Returns the lowercase name used for this status in JSON payloads.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Canceled => "canceled",
            Self::Failed => "failed",
            Self::Succeeded => "succeeded",
            Self::Processing => "processing",
            Self::Starting => "starting",
        }
    }
}
/// Final output of a prediction: either one JSON value or a list of JSON values.
///
/// Serialized with `#[serde(untagged)]`, so no variant name appears in the serialized form.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(untagged)]
pub enum PredictionOutput {
/// A single JSON value.
Single(serde_json::Value),
/// A sequence of JSON values.
Stream(Vec<serde_json::Value>),
}
impl PredictionOutput {
    /// Returns `true` for the `Stream` variant and `false` for `Single`.
    pub fn is_stream(&self) -> bool {
        match self {
            Self::Stream(_) => true,
            Self::Single(_) => false,
        }
    }

    /// Consumes the output and returns its values as a vector.
    ///
    /// A `Single` value becomes a one-element vector; a `Stream` yields its vector unchanged.
    pub fn into_values(self) -> Vec<serde_json::Value> {
        match self {
            Self::Stream(values) => values,
            Self::Single(value) => vec![value],
        }
    }

    /// Returns the value of a `Single` output, or the last element of a `Stream`.
    ///
    /// An empty `Stream` yields a reference to JSON `null`.
    pub fn final_value(&self) -> &serde_json::Value {
        match self {
            Self::Single(value) => value,
            Self::Stream(values) => match values.last() {
                Some(last) => last,
                None => &serde_json::Value::Null,
            },
        }
    }
}
/// An event published on a prediction's event stream.
#[derive(Debug, Clone)]
pub enum PredictionStreamEvent {
/// The prediction was marked as processing; carries its id and status string.
Start {
id: String,
status: String,
},
/// One output chunk together with the index supplied by the producer.
Output {
chunk: serde_json::Value,
index: u64,
},
/// A piece of log text and the source it came from.
Log {
source: LogSource,
data: String,
},
/// A metric update: the metric name, the value, and how it is to be applied.
Metric {
name: String,
value: serde_json::Value,
mode: MetricMode,
},
/// The prediction reached a terminal status; `payload` is the state snapshot built at that point.
Completed {
payload: serde_json::Value,
},
}
/// Reference-counted stream event, so one allocation can be held by both the replay history and
/// the broadcast channel.
pub type SharedPredictionStreamEvent = Arc<PredictionStreamEvent>;
/// Result of subscribing to a prediction's stream with replay of already-emitted events.
pub struct PredictionStreamReplay {
/// Copy of the retained event history at the time of subscription, oldest first.
pub replay: VecDeque<SharedPredictionStreamEvent>,
/// Number of events that had already been evicted from the history when subscribing.
pub skipped: u64,
/// Live receiver for events emitted after subscription.
pub receiver: tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent>,
}
impl PredictionStreamEvent {
pub fn event_name(&self) -> &'static str {
match self {
Self::Start { .. } => "start",
Self::Output { .. } => "output",
Self::Log { .. } => "log",
Self::Metric { .. } => "metric",
Self::Completed { .. } => "completed",
}
}
pub fn json_data(&self) -> serde_json::Value {
match self {
Self::Start { id, status } => serde_json::json!({
"id": id,
"status": status,
}),
Self::Output { chunk, index } => serde_json::json!({
"chunk": chunk,
"index": index,
}),
Self::Log { source, data } => serde_json::json!({
"source": source,
"data": data,
}),
Self::Metric { name, value, mode } => serde_json::json!({
"name": name,
"value": value,
"mode": mode,
}),
Self::Completed { payload } => payload.clone(),
}
}
}
/// State of a single prediction: status, accumulated logs/outputs/metrics, the optional webhook
/// sender, and the event stream (live broadcast plus bounded replay history).
pub struct Prediction {
/// Identifier given at construction.
id: String,
/// Token whose cancellation state is reported by `is_canceled`; clones are handed out by `cancel_token`.
cancel_token: CancellationToken,
/// Instant captured in `new`; `elapsed` measures from here.
started_at: Instant,
/// Current lifecycle status.
status: PredictionStatus,
/// All log text appended so far, concatenated.
logs: String,
/// Output chunks appended so far.
outputs: Vec<serde_json::Value>,
/// Final output, set by `set_succeeded`.
output: Option<PredictionOutput>,
/// Error message, set by `set_failed`.
error: Option<String>,
/// Optional webhook sender; taken (left as `None`) when the terminal webhook fires or via `take_webhook`.
webhook: Option<WebhookSender>,
/// Notified once when the prediction reaches a terminal status.
completion: Arc<Notify>,
/// Sending half of the broadcast channel for live stream events.
stream_tx: tokio::sync::broadcast::Sender<SharedPredictionStreamEvent>,
/// Most recent stream events kept for replay, oldest first.
stream_history: VecDeque<SharedPredictionStreamEvent>,
/// Maximum number of events kept in `stream_history`; `0` disables retention.
stream_history_capacity: usize,
/// Count of events evicted from `stream_history` because it was full.
stream_history_skipped: u64,
/// Metrics keyed by top-level name; dotted names are stored as nested JSON objects.
metrics: HashMap<String, serde_json::Value>,
}
impl Prediction {
    /// Creates a prediction in the `Starting` status.
    ///
    /// The start instant is captured here, a fresh broadcast channel of
    /// `STREAM_CHANNEL_CAPACITY` is created, and the replay history capacity is read from the
    /// environment (see `stream_history_capacity_from_env`).
    pub fn new(id: String, webhook: Option<WebhookSender>) -> Self {
        // The initial receiver is dropped immediately; subscribers are created on demand.
        let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_CHANNEL_CAPACITY);
        let stream_history_capacity = stream_history_capacity_from_env();

        Self {
            id,
            webhook,
            status: PredictionStatus::Starting,
            started_at: Instant::now(),
            cancel_token: CancellationToken::new(),
            completion: Arc::new(Notify::new()),
            logs: String::new(),
            outputs: Vec::new(),
            output: None,
            error: None,
            metrics: HashMap::new(),
            stream_tx,
            stream_history: VecDeque::new(),
            stream_history_capacity,
            stream_history_skipped: 0,
        }
    }

    /// Returns the prediction id.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns a clone of the prediction's cancellation token.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Subscribes to events emitted from now on (no replay of earlier events).
    pub fn subscribe_stream(
        &self,
    ) -> tokio::sync::broadcast::Receiver<SharedPredictionStreamEvent> {
        self.stream_tx.subscribe()
    }

    /// Subscribes to the stream and also returns a copy of the retained history together with
    /// the number of events that were already evicted from it.
    pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay {
        let replay = self.stream_history.clone();
        let skipped = self.stream_history_skipped;
        let receiver = self.stream_tx.subscribe();
        PredictionStreamReplay {
            replay,
            skipped,
            receiver,
        }
    }

    /// Returns the number of live receivers on the broadcast channel.
    pub fn stream_receiver_count(&self) -> usize {
        self.stream_tx.receiver_count()
    }

    /// Records `event` in the bounded history (when enabled) and broadcasts it.
    fn emit_stream_event(&mut self, event: PredictionStreamEvent) {
        let shared: SharedPredictionStreamEvent = Arc::new(event);
        let capacity = self.stream_history_capacity;

        // A capacity of zero turns history retention off entirely.
        if capacity != 0 {
            let history_full = self.stream_history.len() == capacity;
            if history_full {
                // Evict the oldest event and remember that a replay will miss it.
                self.stream_history.pop_front();
                self.stream_history_skipped += 1;
            }
            self.stream_history.push_back(Arc::clone(&shared));
        }

        // The send result is deliberately ignored.
        let _ = self.stream_tx.send(shared);
    }
}
/// Reads the replay history capacity from `COG_STREAM_HISTORY_CAPACITY`.
///
/// Returns the parsed value when the variable holds a non-negative integer. Surrounding
/// whitespace is ignored, so a value such as `"2048\n"` is accepted. Returns
/// `DEFAULT_STREAM_HISTORY_CAPACITY` when the variable is not set, and also (after logging a
/// warning) when it is not valid Unicode or does not parse as `usize`.
fn stream_history_capacity_from_env() -> usize {
    let value = match std::env::var(STREAM_HISTORY_CAPACITY_ENV) {
        Ok(value) => value,
        // Unset is the normal case and is not worth a warning.
        Err(std::env::VarError::NotPresent) => return DEFAULT_STREAM_HISTORY_CAPACITY,
        Err(error) => {
            tracing::warn!(
                env_var = STREAM_HISTORY_CAPACITY_ENV,
                error = %error,
                default = DEFAULT_STREAM_HISTORY_CAPACITY,
                "Invalid stream history capacity; using default"
            );
            return DEFAULT_STREAM_HISTORY_CAPACITY;
        }
    };

    // `usize::from_str` rejects leading/trailing whitespace, so trim first; the warning below
    // still logs the raw value for diagnosis.
    match value.trim().parse::<usize>() {
        Ok(capacity) => capacity,
        Err(error) => {
            tracing::warn!(
                env_var = STREAM_HISTORY_CAPACITY_ENV,
                value,
                error = %error,
                default = DEFAULT_STREAM_HISTORY_CAPACITY,
                "Invalid stream history capacity; using default"
            );
            DEFAULT_STREAM_HISTORY_CAPACITY
        }
    }
}
impl Prediction {
pub fn is_canceled(&self) -> bool {
self.cancel_token.is_cancelled()
}
pub fn status(&self) -> PredictionStatus {
self.status
}
pub fn is_terminal(&self) -> bool {
self.status.is_terminal()
}
pub fn set_processing(&mut self) {
self.status = PredictionStatus::Processing;
self.emit_stream_event(PredictionStreamEvent::Start {
id: self.id.clone(),
status: self.status.as_str().to_string(),
});
self.fire_webhook(WebhookEventType::Start);
}
pub fn set_succeeded(&mut self, output: PredictionOutput) {
if self.status.is_terminal() {
return;
}
self.status = PredictionStatus::Succeeded;
self.output = Some(output);
self.emit_stream_event(PredictionStreamEvent::Completed {
payload: self.build_state_snapshot(),
});
self.fire_terminal_webhook();
self.completion.notify_one();
}
pub fn set_failed(&mut self, error: String) {
if self.status.is_terminal() {
return;
}
self.status = PredictionStatus::Failed;
self.error = Some(error);
self.emit_stream_event(PredictionStreamEvent::Completed {
payload: self.build_state_snapshot(),
});
self.fire_terminal_webhook();
self.completion.notify_one();
}
pub fn set_canceled(&mut self) {
if self.status.is_terminal() {
return;
}
self.status = PredictionStatus::Canceled;
self.emit_stream_event(PredictionStreamEvent::Completed {
payload: self.build_state_snapshot(),
});
self.fire_terminal_webhook();
self.completion.notify_one();
}
pub fn elapsed(&self) -> std::time::Duration {
self.started_at.elapsed()
}
pub fn append_log(&mut self, data: &str) {
self.append_log_source(LogSource::Stdout, data);
}
pub fn append_log_source(&mut self, source: LogSource, data: &str) {
self.logs.push_str(data);
self.emit_stream_event(PredictionStreamEvent::Log {
source,
data: data.to_string(),
});
self.fire_webhook(WebhookEventType::Logs);
}
pub fn logs(&self) -> &str {
&self.logs
}
pub fn set_metric(&mut self, name: String, value: serde_json::Value, mode: MetricMode) {
if name.is_empty() || name.split('.').any(|s| s.is_empty()) {
tracing::warn!(key = %name, "Ignoring metric with empty key or empty dot-path segment");
return;
}
self.emit_stream_event(PredictionStreamEvent::Metric {
name: name.clone(),
value: value.clone(),
mode,
});
let parts: Vec<&str> = name.split('.').collect();
if parts.len() > 1 {
self.set_metric_dotpath(&parts, value, mode);
return;
}
match mode {
MetricMode::Replace => {
if value.is_null() {
self.metrics.remove(&name);
} else {
self.metrics.insert(name, value);
}
}
MetricMode::Increment => {
let entry = self.metrics.entry(name).or_insert(serde_json::json!(0));
if let (Some(a), Some(b)) = (entry.as_i64(), value.as_i64()) {
*entry = serde_json::json!(a.wrapping_add(b));
} else if let (Some(a), Some(b)) = (entry.as_u64(), value.as_u64()) {
*entry = serde_json::json!(a.wrapping_add(b));
} else if let (Some(a), Some(b)) = (entry.as_f64(), value.as_f64()) {
*entry = serde_json::json!(a + b);
}
}
MetricMode::Append => {
let entry = self
.metrics
.entry(name)
.or_insert(serde_json::Value::Array(vec![]));
if let Some(arr) = entry.as_array_mut() {
arr.push(value);
} else {
let existing = entry.take();
*entry = serde_json::json!([existing, value]);
}
}
}
}
fn set_metric_dotpath(&mut self, parts: &[&str], value: serde_json::Value, mode: MetricMode) {
debug_assert!(parts.len() > 1);
let root_key = parts[0].to_string();
let entry = self
.metrics
.entry(root_key)
.or_insert_with(|| serde_json::json!({}));
let mut current = entry;
for &part in &parts[1..parts.len() - 1] {
if !current.is_object() {
*current = serde_json::json!({});
}
current = current
.as_object_mut()
.unwrap()
.entry(part)
.or_insert_with(|| serde_json::json!({}));
}
let leaf_key = parts[parts.len() - 1];
if !current.is_object() {
*current = serde_json::json!({});
}
let obj = current.as_object_mut().unwrap();
match mode {
MetricMode::Replace => {
if value.is_null() {
obj.remove(leaf_key);
} else {
obj.insert(leaf_key.to_string(), value);
}
}
MetricMode::Increment => {
let entry = obj.entry(leaf_key).or_insert(serde_json::json!(0));
if let (Some(a), Some(b)) = (entry.as_i64(), value.as_i64()) {
*entry = serde_json::json!(a.wrapping_add(b));
} else if let (Some(a), Some(b)) = (entry.as_u64(), value.as_u64()) {
*entry = serde_json::json!(a.wrapping_add(b));
} else if let (Some(a), Some(b)) = (entry.as_f64(), value.as_f64()) {
*entry = serde_json::json!(a + b);
}
}
MetricMode::Append => {
let entry = obj
.entry(leaf_key)
.or_insert(serde_json::Value::Array(vec![]));
if let Some(arr) = entry.as_array_mut() {
arr.push(value);
} else {
let existing = entry.take();
*entry = serde_json::json!([existing, value]);
}
}
}
}
pub fn metrics(&self) -> &HashMap<String, serde_json::Value> {
&self.metrics
}
pub fn append_output(&mut self, output: serde_json::Value) {
let index = self.outputs.len() as u64;
self.append_output_chunk(output, index);
}
pub fn append_output_chunk(&mut self, output: serde_json::Value, index: u64) {
self.outputs.push(output.clone());
self.emit_stream_event(PredictionStreamEvent::Output {
chunk: output,
index,
});
self.fire_webhook(WebhookEventType::Output);
}
pub fn outputs(&self) -> &[serde_json::Value] {
&self.outputs
}
pub fn take_outputs(&mut self) -> Vec<serde_json::Value> {
std::mem::take(&mut self.outputs)
}
pub fn output(&self) -> Option<&PredictionOutput> {
self.output.as_ref()
}
pub fn error(&self) -> Option<&str> {
self.error.as_deref()
}
pub async fn wait(&self) {
if self.status.is_terminal() {
return;
}
self.completion.notified().await;
}
pub fn completion(&self) -> Arc<Notify> {
Arc::clone(&self.completion)
}
pub fn take_webhook(&mut self) -> Option<WebhookSender> {
self.webhook.take()
}
fn fire_webhook(&self, event: WebhookEventType) {
if let Some(ref webhook) = self.webhook {
let payload = self.build_webhook_payload();
webhook.send(event, &payload);
}
}
fn fire_terminal_webhook(&mut self) {
if let Some(webhook) = self.webhook.take() {
let payload = self.build_webhook_payload();
tokio::spawn(async move {
webhook
.send_terminal(WebhookEventType::Completed, &payload)
.await;
});
}
}
pub fn build_state_snapshot(&self) -> serde_json::Value {
let mut payload = serde_json::json!({
"id": self.id,
"status": self.status.as_str(),
"logs": self.logs,
});
if let Some(ref output) = self.output {
payload["output"] = serde_json::json!(output);
} else if !self.outputs.is_empty() {
payload["output"] = serde_json::json!(self.outputs);
}
if let Some(ref error) = self.error {
payload["error"] = serde_json::Value::String(error.clone());
}
let mut metrics_obj = serde_json::Map::new();
for (k, v) in &self.metrics {
metrics_obj.insert(k.clone(), v.clone());
}
if self.status.is_terminal() {
let predict_time = self.elapsed().as_secs_f64();
metrics_obj.insert("predict_time".to_string(), serde_json::json!(predict_time));
}
payload["metrics"] = serde_json::Value::Object(metrics_obj);
payload
}
fn build_webhook_payload(&self) -> serde_json::Value {
self.build_state_snapshot()
}
pub fn build_terminal_response(&self) -> serde_json::Value {
self.build_state_snapshot()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Restores `COG_STREAM_HISTORY_CAPACITY` to its previous value on drop and holds the
    /// process-wide lock that serializes tests mutating that variable.
    struct StreamHistoryCapacityEnvGuard {
        previous: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for StreamHistoryCapacityEnvGuard {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here (fields are dropped after this body), so no
            // other guard-holding test mutates the variable concurrently.
            match &self.previous {
                Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) },
                None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) },
            }
        }
    }

    /// Sets (or removes, for `None`) the history-capacity variable for the lifetime of the
    /// returned guard.
    fn set_stream_history_capacity(value: Option<&str>) -> StreamHistoryCapacityEnvGuard {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        // A test that fails while holding the guard poisons the mutex. The lock protects no
        // data, so recover the guard instead of cascading the failure into every later
        // env-guarded test.
        let lock = LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        let previous = std::env::var_os(STREAM_HISTORY_CAPACITY_ENV);
        // SAFETY: `lock` is held, so writers of this variable are serialized.
        // NOTE(review): tests that build a `Prediction` without this guard still read the
        // variable concurrently through `std::env::var` — confirm that is acceptable here.
        match value {
            Some(value) => unsafe { std::env::set_var(STREAM_HISTORY_CAPACITY_ENV, value) },
            None => unsafe { std::env::remove_var(STREAM_HISTORY_CAPACITY_ENV) },
        }
        StreamHistoryCapacityEnvGuard {
            previous,
            _lock: lock,
        }
    }

    #[test]
    fn status_is_terminal() {
        assert!(!PredictionStatus::Starting.is_terminal());
        assert!(!PredictionStatus::Processing.is_terminal());
        assert!(PredictionStatus::Succeeded.is_terminal());
        assert!(PredictionStatus::Failed.is_terminal());
        assert!(PredictionStatus::Canceled.is_terminal());
    }

    #[test]
    fn new_starts_in_starting_status() {
        let pred = Prediction::new("test".to_string(), None);
        assert_eq!(pred.status(), PredictionStatus::Starting);
        assert_eq!(pred.id(), "test");
    }

    #[test]
    fn set_succeeded() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("hello")));
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
    }

    #[test]
    fn set_failed() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_failed("something went wrong".to_string());
        assert_eq!(pred.status(), PredictionStatus::Failed);
    }

    #[test]
    fn set_canceled() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_canceled();
        assert_eq!(pred.status(), PredictionStatus::Canceled);
    }

    #[test]
    fn cancel_token_works() {
        let pred = Prediction::new("test".to_string(), None);
        let token = pred.cancel_token();
        assert!(!pred.is_canceled());
        token.cancel();
        assert!(pred.is_canceled());
    }

    #[test]
    fn elapsed_time_increases() {
        let pred = Prediction::new("test".to_string(), None);
        let t1 = pred.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let t2 = pred.elapsed();
        assert!(t2 > t1);
    }

    #[test]
    fn append_log() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.append_log("line 1\n");
        pred.append_log("line 2\n");
        assert_eq!(pred.logs(), "line 1\nline 2\n");
    }

    #[test]
    fn append_output() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.append_output(serde_json::json!("chunk1"));
        pred.append_output(serde_json::json!("chunk2"));
        assert_eq!(pred.outputs().len(), 2);
    }

    #[tokio::test]
    async fn prediction_stream_emits_start_output_log_and_completed() {
        let mut prediction = Prediction::new("pred_stream".to_string(), None);
        let mut rx = prediction.subscribe_stream();
        prediction.set_processing();
        prediction.append_output_chunk(serde_json::json!("hello"), 0);
        prediction.append_log("loading\n");
        prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")]));

        let start = rx.recv().await.unwrap();
        assert_eq!(start.event_name(), "start");
        assert_eq!(
            start.json_data(),
            serde_json::json!({"id":"pred_stream","status":"processing"})
        );

        let output = rx.recv().await.unwrap();
        assert_eq!(output.event_name(), "output");
        assert_eq!(
            output.json_data(),
            serde_json::json!({"chunk":"hello","index":0})
        );

        let log = rx.recv().await.unwrap();
        assert_eq!(log.event_name(), "log");
        assert_eq!(
            log.json_data(),
            serde_json::json!({"source":"stdout","data":"loading\n"})
        );

        let completed = rx.recv().await.unwrap();
        assert_eq!(completed.event_name(), "completed");
        assert_eq!(completed.json_data()["id"], "pred_stream");
        assert_eq!(completed.json_data()["status"], "succeeded");
        assert_eq!(
            completed.json_data()["output"],
            serde_json::json!(["hello"])
        );
    }

    #[tokio::test]
    async fn prediction_stream_emits_metric_event() {
        let mut prediction = Prediction::new("pred_metric".to_string(), None);
        let mut rx = prediction.subscribe_stream();
        prediction.set_metric(
            "tokens".to_string(),
            serde_json::json!(1),
            MetricMode::Increment,
        );
        let event = rx.recv().await.unwrap();
        assert_eq!(event.event_name(), "metric");
        assert_eq!(
            event.json_data(),
            serde_json::json!({
                "name":"tokens",
                "value":1,
                "mode":"increment"
            })
        );
    }

    #[tokio::test]
    async fn prediction_stream_preserves_log_source() {
        let mut prediction = Prediction::new("pred_log_source".to_string(), None);
        let mut rx = prediction.subscribe_stream();
        prediction.append_log_source(crate::bridge::protocol::LogSource::Stderr, "warning\n");
        let event = rx.recv().await.unwrap();
        assert_eq!(event.event_name(), "log");
        assert_eq!(
            event.json_data(),
            serde_json::json!({"source":"stderr","data":"warning\n"})
        );
    }

    #[tokio::test]
    async fn prediction_stream_replay_includes_already_emitted_events() {
        let _guard = set_stream_history_capacity(None);
        let mut prediction = Prediction::new("pred_replay".to_string(), None);
        prediction.set_processing();
        prediction.append_output_chunk(serde_json::json!("hello"), 0);
        prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")]));
        let replay = prediction.subscribe_stream_replay();
        let events: Vec<&str> = replay
            .replay
            .iter()
            .map(|event| event.event_name())
            .collect();
        assert_eq!(events, vec!["start", "output", "completed"]);
        assert_eq!(replay.skipped, 0);
        assert_eq!(
            replay.replay[1].json_data(),
            serde_json::json!({"chunk":"hello","index":0})
        );
        assert_eq!(replay.replay[2].json_data()["status"], "succeeded");
    }

    #[tokio::test]
    async fn prediction_stream_replay_is_bounded_to_recent_events() {
        let _guard = set_stream_history_capacity(None);
        let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None);
        prediction.set_processing();
        for index in 0..1100 {
            prediction.append_output_chunk(serde_json::json!(index), index);
        }
        let replay = prediction.subscribe_stream_replay();
        assert_eq!(replay.replay.len(), DEFAULT_STREAM_HISTORY_CAPACITY);
        // 1 start event + 1100 outputs = 1101 events; 1101 - 1024 = 77 evicted.
        assert_eq!(replay.skipped, 77);
        assert_eq!(
            replay.replay[0].json_data(),
            serde_json::json!({"chunk":76,"index":76})
        );
        assert_eq!(
            replay.replay[DEFAULT_STREAM_HISTORY_CAPACITY - 1].json_data(),
            serde_json::json!({"chunk":1099,"index":1099})
        );
    }

    #[tokio::test]
    async fn prediction_stream_replay_uses_configured_history_capacity() {
        let _guard = set_stream_history_capacity(Some("2"));
        let mut prediction = Prediction::new("pred_replay_configured".to_string(), None);
        prediction.append_output_chunk(serde_json::json!(0), 0);
        prediction.append_output_chunk(serde_json::json!(1), 1);
        prediction.append_output_chunk(serde_json::json!(2), 2);
        let replay = prediction.subscribe_stream_replay();
        assert_eq!(replay.replay.len(), 2);
        assert_eq!(replay.skipped, 1);
        assert_eq!(
            replay.replay[0].json_data(),
            serde_json::json!({"chunk":1,"index":1})
        );
        assert_eq!(
            replay.replay[1].json_data(),
            serde_json::json!({"chunk":2,"index":2})
        );
    }

    #[tokio::test]
    async fn prediction_stream_replay_can_be_disabled_with_zero_capacity() {
        let _guard = set_stream_history_capacity(Some("0"));
        let mut prediction = Prediction::new("pred_replay_disabled".to_string(), None);
        prediction.append_output_chunk(serde_json::json!(0), 0);
        prediction.append_output_chunk(serde_json::json!(1), 1);
        let replay = prediction.subscribe_stream_replay();
        assert!(replay.replay.is_empty());
        assert_eq!(replay.skipped, 0);
    }

    #[tokio::test]
    async fn prediction_stream_replay_uses_default_for_invalid_capacity() {
        let _guard = set_stream_history_capacity(Some("nope"));
        let mut prediction = Prediction::new("pred_replay_invalid".to_string(), None);
        prediction.append_output_chunk(serde_json::json!(0), 0);
        let replay = prediction.subscribe_stream_replay();
        assert_eq!(replay.replay.len(), 1);
        assert_eq!(replay.skipped, 0);
    }

    #[tokio::test]
    async fn wait_returns_immediately_if_terminal() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("done")));
        pred.wait().await;
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
    }

    #[test]
    fn prediction_output_single() {
        let output = PredictionOutput::Single(serde_json::json!("hello"));
        assert!(!output.is_stream());
        assert_eq!(output.into_values(), vec![serde_json::json!("hello")]);
    }

    #[test]
    fn prediction_output_stream() {
        let output = PredictionOutput::Stream(vec![serde_json::json!("a"), serde_json::json!("b")]);
        assert!(output.is_stream());
    }

    #[test]
    fn metric_replace_sets_value() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        assert_eq!(pred.metrics()["temp"], serde_json::json!(0.7));
    }

    #[test]
    fn metric_replace_overwrites() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        pred.set_metric("temp".into(), serde_json::json!(0.9), MetricMode::Replace);
        assert_eq!(pred.metrics()["temp"], serde_json::json!(0.9));
    }

    #[test]
    fn metric_replace_null_deletes() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        pred.set_metric("temp".into(), serde_json::Value::Null, MetricMode::Replace);
        assert!(!pred.metrics().contains_key("temp"));
    }

    #[test]
    fn metric_increment_integers() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("count".into(), serde_json::json!(1), MetricMode::Increment);
        pred.set_metric("count".into(), serde_json::json!(3), MetricMode::Increment);
        assert_eq!(pred.metrics()["count"], serde_json::json!(4));
    }

    #[test]
    fn metric_increment_floats() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "score".into(),
            serde_json::json!(1.5),
            MetricMode::Increment,
        );
        pred.set_metric(
            "score".into(),
            serde_json::json!(2.5),
            MetricMode::Increment,
        );
        assert_eq!(pred.metrics()["score"], serde_json::json!(4.0));
    }

    #[test]
    fn metric_increment_creates_from_zero() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("count".into(), serde_json::json!(5), MetricMode::Increment);
        assert_eq!(pred.metrics()["count"], serde_json::json!(5));
    }

    #[test]
    fn metric_append_creates_array() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "logprobs".into(),
            serde_json::json!(-1.2),
            MetricMode::Append,
        );
        pred.set_metric(
            "logprobs".into(),
            serde_json::json!(-0.3),
            MetricMode::Append,
        );
        assert_eq!(pred.metrics()["logprobs"], serde_json::json!([-1.2, -0.3]));
    }

    #[test]
    fn metric_append_to_non_array_wraps() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("val".into(), serde_json::json!(1), MetricMode::Replace);
        pred.set_metric("val".into(), serde_json::json!(2), MetricMode::Append);
        assert_eq!(pred.metrics()["val"], serde_json::json!([1, 2]));
    }

    #[test]
    fn metric_dotpath_creates_nested() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "timing.preprocess".into(),
            serde_json::json!(0.1),
            MetricMode::Replace,
        );
        assert_eq!(
            pred.metrics()["timing"],
            serde_json::json!({"preprocess": 0.1})
        );
    }

    #[test]
    fn metric_dotpath_deep() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("a.b.c".into(), serde_json::json!(42), MetricMode::Replace);
        assert_eq!(pred.metrics()["a"], serde_json::json!({"b": {"c": 42}}));
    }

    #[test]
    fn metric_dotpath_multiple_leaves() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "timing.preprocess".into(),
            serde_json::json!(0.1),
            MetricMode::Replace,
        );
        pred.set_metric(
            "timing.inference".into(),
            serde_json::json!(0.8),
            MetricMode::Replace,
        );
        assert_eq!(
            pred.metrics()["timing"],
            serde_json::json!({"preprocess": 0.1, "inference": 0.8})
        );
    }

    #[test]
    fn metric_dotpath_delete_leaf() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "timing.preprocess".into(),
            serde_json::json!(0.1),
            MetricMode::Replace,
        );
        pred.set_metric(
            "timing.preprocess".into(),
            serde_json::Value::Null,
            MetricMode::Replace,
        );
        assert_eq!(pred.metrics()["timing"], serde_json::json!({}));
    }

    #[test]
    fn metric_dotpath_increment() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "stats.tokens".into(),
            serde_json::json!(10),
            MetricMode::Increment,
        );
        pred.set_metric(
            "stats.tokens".into(),
            serde_json::json!(5),
            MetricMode::Increment,
        );
        assert_eq!(pred.metrics()["stats"], serde_json::json!({"tokens": 15}));
    }

    #[test]
    fn metric_complex_values() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "config".into(),
            serde_json::json!({"layers": 12, "heads": 8}),
            MetricMode::Replace,
        );
        pred.set_metric(
            "scores".into(),
            serde_json::json!([0.9, 0.8, 0.7]),
            MetricMode::Replace,
        );
        assert_eq!(
            pred.metrics()["config"],
            serde_json::json!({"layers": 12, "heads": 8})
        );
        assert_eq!(pred.metrics()["scores"], serde_json::json!([0.9, 0.8, 0.7]));
    }

    #[test]
    fn terminal_snapshot_merges_metrics_with_predict_time() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        pred.set_metric("count".into(), serde_json::json!(42), MetricMode::Replace);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok")));
        let snapshot = pred.build_state_snapshot();
        let metrics = snapshot["metrics"].as_object().unwrap();
        assert_eq!(metrics["temp"], serde_json::json!(0.7));
        assert_eq!(metrics["count"], serde_json::json!(42));
        assert!(metrics.contains_key("predict_time"));
    }

    #[test]
    fn terminal_snapshot_predict_time_overrides_user() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(
            "predict_time".into(),
            serde_json::json!(999.0),
            MetricMode::Replace,
        );
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok")));
        let snapshot = pred.build_state_snapshot();
        let metrics = snapshot["metrics"].as_object().unwrap();
        assert_ne!(metrics["predict_time"], serde_json::json!(999.0));
    }

    #[test]
    fn terminal_state_guard_set_failed_after_succeeded() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok")));
        pred.set_failed("Slot dropped unexpectedly".to_string());
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
        assert!(pred.error().is_none());
    }

    #[test]
    fn terminal_state_guard_set_succeeded_after_failed() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_failed("error".to_string());
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("late")));
        assert_eq!(pred.status(), PredictionStatus::Failed);
        assert_eq!(pred.error(), Some("error"));
    }

    #[test]
    fn terminal_state_guard_set_canceled_after_succeeded() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("done")));
        pred.set_canceled();
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
    }

    #[test]
    fn metric_empty_key_ignored() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("".into(), serde_json::json!(1), MetricMode::Replace);
        assert!(pred.metrics().is_empty());
    }

    #[test]
    fn metric_trailing_dot_ignored() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("a.".into(), serde_json::json!(1), MetricMode::Replace);
        assert!(pred.metrics().is_empty());
    }

    #[test]
    fn metric_leading_dot_ignored() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric(".b".into(), serde_json::json!(1), MetricMode::Replace);
        assert!(pred.metrics().is_empty());
    }

    #[test]
    fn metric_double_dot_ignored() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("a..b".into(), serde_json::json!(1), MetricMode::Replace);
        assert!(pred.metrics().is_empty());
    }

    #[test]
    fn snapshot_includes_empty_metrics_for_non_terminal() {
        let pred = Prediction::new("test".to_string(), None);
        let snapshot = pred.build_state_snapshot();
        assert!(snapshot["metrics"].is_object());
        assert!(snapshot["metrics"].as_object().unwrap().is_empty());
    }

    #[test]
    fn snapshot_includes_predict_time_on_failed() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        pred.set_failed("oops".to_string());
        let snapshot = pred.build_state_snapshot();
        let metrics = snapshot["metrics"].as_object().unwrap();
        assert_eq!(metrics["temp"], serde_json::json!(0.7));
        assert!(metrics.contains_key("predict_time"));
    }
}