use crate::budget::BudgetTracker;
use crate::duplex::{DuplexSession, TurnResult};
use crate::error::Result;
/// A multi-turn conversation layered over a [`DuplexSession`].
///
/// Records every completed [`TurnResult`], keeps running cost and turn totals,
/// and optionally holds a [`BudgetTracker`] that is checked before each send.
#[derive(Debug)]
pub struct Conversation {
    /// Session that each prompt is sent through.
    inner: DuplexSession,
    /// Turn history, running totals, and the optional budget.
    state: ConversationState,
}
/// Bookkeeping behind a [`Conversation`]: turn history, running totals, and the
/// optional budget.
///
/// The derived `Default` is an empty history, zeroed totals, and no budget.
#[derive(Debug, Default)]
struct ConversationState {
    /// Every recorded turn, oldest first.
    history: Vec<TurnResult>,
    /// Sum of each recorded turn's reported `total_cost_usd`; turns that report
    /// no cost contribute `0.0`.
    cumulative_cost_usd: f64,
    /// Number of recorded turns; saturates at `u32::MAX` instead of wrapping.
    cumulative_turns: u32,
    /// Tracker credited with each recorded turn's cost, if one is attached.
    budget: Option<BudgetTracker>,
}
impl ConversationState {
    /// Appends a completed turn to the history and folds its cost into the
    /// running totals.
    ///
    /// A turn that reports no `total_cost_usd` is counted as costing `0.0`.
    /// When a [`BudgetTracker`] is attached, the same cost is forwarded to it.
    fn record(&mut self, turn: TurnResult) {
        let spent = turn.total_cost_usd().unwrap_or_default();

        self.cumulative_cost_usd += spent;
        // Saturate rather than `+= 1`, which would panic on overflow in debug builds.
        self.cumulative_turns = self.cumulative_turns.saturating_add(1);

        if let Some(tracker) = self.budget.as_ref() {
            tracker.record(spent);
        }

        self.history.push(turn);
    }
}
impl Conversation {
    /// Wraps an open [`DuplexSession`] with an empty history, zeroed totals,
    /// and no budget.
    #[must_use]
    pub fn new(session: DuplexSession) -> Self {
        Self {
            inner: session,
            state: ConversationState::default(),
        }
    }

    /// Attaches a [`BudgetTracker`]. It is checked before every
    /// [`send`](Self::send) and credited with the cost of every turn recorded
    /// afterwards.
    ///
    /// Cost already accumulated in [`total_cost_usd`](Self::total_cost_usd)
    /// before this call is *not* replayed into the tracker.
    #[must_use]
    pub fn with_budget(mut self, budget: BudgetTracker) -> Self {
        self.state.budget = Some(budget);
        self
    }

    /// Returns the attached budget tracker, if any.
    #[must_use]
    pub fn budget(&self) -> Option<&BudgetTracker> {
        self.state.budget.as_ref()
    }

    /// Sends `prompt` as the next turn, records the result, and returns a
    /// reference to it.
    ///
    /// # Errors
    ///
    /// - If a budget is attached and [`BudgetTracker::check`] fails, that error
    ///   is returned and nothing is sent.
    /// - Any error from the underlying session's `send` is propagated.
    ///
    /// A turn that fails is not recorded in the history or the totals.
    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<&TurnResult> {
        if let Some(budget) = &self.state.budget {
            budget.check()?;
        }
        let turn = self.inner.send(prompt).await?;
        self.state.record(turn);
        // `record` always pushes onto `history`, so `last()` cannot be `None` here.
        Ok(self
            .state
            .history
            .last()
            .expect("just-pushed entry must be present"))
    }

    /// All recorded turns, oldest first.
    #[must_use]
    pub fn history(&self) -> &[TurnResult] {
        &self.state.history
    }

    /// The most recently recorded turn, or `None` before the first successful send.
    #[must_use]
    pub fn last(&self) -> Option<&TurnResult> {
        self.state.history.last()
    }

    /// Sum of the `total_cost_usd` reported by every recorded turn.
    ///
    /// Turns that report no cost count as `0.0`.
    #[must_use]
    pub fn total_cost_usd(&self) -> f64 {
        self.state.cumulative_cost_usd
    }

    /// Number of recorded turns, saturating at `u32::MAX`.
    #[must_use]
    pub fn total_turns(&self) -> u32 {
        self.state.cumulative_turns
    }

    /// Session id reported by the most recent turn that carried one.
    ///
    /// The history is searched newest-first, so a trailing result that omits
    /// `session_id` does not hide the id an earlier turn reported. Returns
    /// `None` if no recorded turn has reported a session id.
    #[must_use]
    pub fn session_id(&self) -> Option<&str> {
        self.state
            .history
            .iter()
            .rev()
            .find_map(TurnResult::session_id)
    }

    /// Borrows the underlying duplex session.
    #[must_use]
    pub fn session(&self) -> &DuplexSession {
        &self.inner
    }

    /// Closes the underlying session, consuming the conversation.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the session's `close`.
    pub async fn close(self) -> Result<()> {
        self.inner.close().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal `result` payload that carries `session_id` and
    /// `total_cost_usd`.
    fn turn(session_id: &str, cost: f64) -> TurnResult {
        TurnResult {
            result: json!({
                "type": "result",
                "result": "ok",
                "session_id": session_id,
                "total_cost_usd": cost,
            }),
            events: vec![],
        }
    }

    #[test]
    fn record_pushes_turn_and_updates_counters() {
        let mut state = ConversationState::default();
        state.record(turn("sess-1", 0.05));
        assert_eq!(state.history.len(), 1);
        assert_eq!(state.cumulative_turns, 1);
        assert!((state.cumulative_cost_usd - 0.05).abs() < 1e-9);
        assert_eq!(state.history[0].session_id(), Some("sess-1"));
    }

    #[test]
    fn record_accumulates_across_turns() {
        let mut state = ConversationState::default();
        state.record(turn("sess-1", 0.01));
        state.record(turn("sess-1", 0.02));
        state.record(turn("sess-1", 0.03));
        assert_eq!(state.history.len(), 3);
        assert_eq!(state.cumulative_turns, 3);
        assert!((state.cumulative_cost_usd - 0.06).abs() < 1e-9);
    }

    #[test]
    fn record_treats_missing_cost_as_zero() {
        let mut state = ConversationState::default();
        let bare = TurnResult {
            result: json!({ "type": "result", "session_id": "sess-1" }),
            events: vec![],
        };
        state.record(bare);
        assert_eq!(state.cumulative_turns, 1);
        assert_eq!(state.cumulative_cost_usd, 0.0);
    }

    #[test]
    fn record_saturates_turn_counter() {
        let mut state = ConversationState {
            cumulative_turns: u32::MAX,
            ..Default::default()
        };
        state.record(turn("sess-1", 0.01));
        // The counter must stay pinned at the maximum instead of wrapping or panicking.
        assert_eq!(state.cumulative_turns, u32::MAX);
        assert_eq!(state.history.len(), 1);
    }

    #[test]
    fn record_forwards_cost_to_budget() {
        let budget = BudgetTracker::builder().build();
        let mut state = ConversationState {
            budget: Some(budget.clone()),
            ..Default::default()
        };
        state.record(turn("sess-1", 0.07));
        assert!((budget.total_usd() - 0.07).abs() < 1e-9);
        assert!((state.cumulative_cost_usd - 0.07).abs() < 1e-9);
    }

    /// Exercises only `BudgetTracker::check`, the guard that `Conversation::send`
    /// runs before dispatching. These unit tests never build a `DuplexSession`,
    /// so `send` itself is not called here.
    #[test]
    fn budget_check_errors_once_spend_exceeds_max() {
        use crate::error::Error;
        let budget = BudgetTracker::builder().max_usd(0.10).build();
        budget.record(0.15);
        match budget.check() {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!((total_usd - 0.15).abs() < 1e-9);
                assert!((max_usd - 0.10).abs() < 1e-9);
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    /// Checks the newest-entry lookup on `ConversationState` directly. No
    /// `Conversation` is built, because that needs a `DuplexSession`.
    #[test]
    fn last_returns_most_recent() {
        let mut state = ConversationState::default();
        assert!(state.history.last().is_none());
        state.record(turn("sess-1", 0.01));
        state.record(turn("sess-1", 0.02));
        let last = state.history.last().expect("last entry present");
        assert!((last.total_cost_usd().unwrap() - 0.02).abs() < 1e-9);
    }

    /// Checks that the newest turn's `session_id` is readable from the history.
    /// This runs on `ConversationState` directly, without a `DuplexSession`.
    #[test]
    fn session_id_pulled_from_last_turn() {
        let mut state = ConversationState::default();
        assert!(state
            .history
            .last()
            .and_then(TurnResult::session_id)
            .is_none());
        state.record(turn("sess-A", 0.01));
        state.record(turn("sess-B", 0.02));
        assert_eq!(
            state.history.last().and_then(TurnResult::session_id),
            Some("sess-B")
        );
    }
}