use anyhow::Result;
use async_trait::async_trait;
use koda_core::{
engine::{EngineCommand, EngineEvent, event::TurnEndReason},
persistence::Persistence,
providers::{LlmResponse, ModelInfo},
session::KodaSession,
tools::ToolRegistry,
trust::TrustMode,
};
use koda_test_utils::{
ChatMessage, Env, LlmProvider, MockProvider, MockResponse, TestSink, ToolDefinition,
};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Builds a [`KodaSession`] for `env` that talks to `provider` and runs in
/// [`TrustMode::Auto`].
///
/// Returns the session together with a clone of the cancellation token stored
/// in the session, so a test can cancel an in-flight turn.
async fn make_session(
    env: &Env,
    provider: Box<dyn LlmProvider>,
) -> (KodaSession, CancellationToken) {
    let cancel = CancellationToken::new();

    let tools = ToolRegistry::new(env.root.clone(), env.config.max_context_tokens);
    // Take the definitions from the registry the agent will own instead of
    // constructing a second, throwaway `ToolRegistry` just for this call.
    let tool_defs = tools.get_definitions(&[], &[]);
    let agent = Arc::new(koda_core::agent::KodaAgent {
        project_root: env.root.clone(),
        tools,
        tool_defs,
        system_prompt: "You are a test assistant.".to_string(),
    });
    agent
        .tools
        .set_session(Arc::new(env.db.clone()), env.session_id.clone());

    let file_tracker =
        koda_core::file_tracker::FileTracker::new(&env.session_id, env.db.clone()).await;

    let session = KodaSession {
        id: env.session_id.clone(),
        agent,
        db: env.db.clone(),
        provider,
        mode: TrustMode::Auto,
        cancel: cancel.clone(),
        file_tracker,
        title_set: false,
    };
    (session, cancel)
}
/// A successful single-response turn must open with `TurnStart`, close with
/// `TurnEnd { reason: Complete }`, and both events must carry the same
/// `turn_id`.
#[tokio::test]
async fn session_run_turn_emits_turn_start_and_end() {
    let env = Env::new().await;
    env.insert_user_message("say hello").await;
    let provider = Box::new(MockProvider::new(vec![MockResponse::Text(
        "Hello!".to_string(),
    )]));
    let (mut session, _cancel) = make_session(&env, provider).await;
    let sink = TestSink::new();
    // Bind the sender to a named `_` variable so it stays alive for the whole
    // turn; a bare `_` pattern would drop it at once and close the channel
    // before `run_turn` ever sees it.
    let (_cmd_tx, mut cmd_rx) = mpsc::channel::<EngineCommand>(1);

    session
        .run_turn(&env.config, None, &sink, &mut cmd_rx)
        .await
        .expect("run_turn should succeed");

    let events = sink.events();
    let start_id = match events.first() {
        Some(EngineEvent::TurnStart { turn_id, .. }) => turn_id,
        other => panic!("first event must be TurnStart, got: {other:?}"),
    };
    let (end_id, reason) = match events.last() {
        Some(EngineEvent::TurnEnd {
            turn_id, reason, ..
        }) => (turn_id, reason),
        other => panic!("last event must be TurnEnd, got: {other:?}"),
    };

    assert_eq!(
        *reason,
        TurnEndReason::Complete,
        "TurnEnd reason should be Complete after successful turn"
    );
    assert_eq!(
        start_id, end_id,
        "TurnStart and TurnEnd must share the same turn_id"
    );
}
/// Cancelling the session's token while the provider is stalled must make
/// `run_turn` return promptly and without error, emitting a `TurnEnd` whose
/// reason is `Cancelled`.
#[tokio::test]
async fn session_cancellation_produces_turn_end_cancelled() {
    let env = Env::new().await;
    env.insert_user_message("hello").await;

    /// Provider whose streaming call stalls for 60 s, far longer than the
    /// test's budget, so only cancellation can end the turn in time.
    /// `chat` is `unreachable!()`: this test expects `run_turn` to use
    /// `chat_stream`.
    struct HangingProvider;

    #[async_trait]
    impl LlmProvider for HangingProvider {
        async fn chat(
            &self,
            _: &[ChatMessage],
            _: &[ToolDefinition],
            _: &koda_core::config::ModelSettings,
        ) -> Result<LlmResponse> {
            unreachable!()
        }

        async fn chat_stream(
            &self,
            _: &[ChatMessage],
            _: &[ToolDefinition],
            _: &koda_core::config::ModelSettings,
        ) -> Result<koda_core::providers::stream_collector::SseCollector> {
            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
            unreachable!()
        }

        async fn list_models(&self) -> Result<Vec<ModelInfo>> {
            Ok(vec![])
        }

        fn provider_name(&self) -> &str {
            "hanging"
        }
    }

    let (mut session, cancel) = make_session(&env, Box::new(HangingProvider)).await;
    // Fire the cancellation shortly after the turn has started; the session
    // keeps its own clone of the token, so this one can be moved away.
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        cancel.cancel();
    });

    let sink = TestSink::new();
    // Keep the sender alive so the command channel stays open during the turn.
    let (_cmd_tx, mut cmd_rx) = mpsc::channel::<EngineCommand>(1);

    let start = std::time::Instant::now();
    // Guard with a timeout so a cancellation regression fails fast with a
    // clear message instead of waiting out the provider's 60 s stall.
    let result = tokio::time::timeout(
        std::time::Duration::from_secs(5),
        session.run_turn(&env.config, None, &sink, &mut cmd_rx),
    )
    .await
    .expect("run_turn did not return within 5s; cancellation was not honoured");
    let elapsed = start.elapsed();

    result.expect("cancellation should be graceful, not an error");
    assert!(
        elapsed < std::time::Duration::from_secs(2),
        "cancellation should unblock quickly, took {elapsed:?}"
    );

    let events = sink.events();
    let turn_end_reason = events.iter().find_map(|event| match event {
        EngineEvent::TurnEnd { reason, .. } => Some(reason),
        _ => None,
    });
    match turn_end_reason {
        Some(reason) => assert_eq!(
            *reason,
            TurnEndReason::Cancelled,
            "TurnEnd reason must be Cancelled when the token is cancelled"
        ),
        None => panic!("expected a TurnEnd event after cancellation, got: {events:?}"),
    }
}
#[tokio::test]
async fn session_persists_messages_across_two_turns() {
let env = Env::new().await;
env.insert_user_message("first question").await;
let provider1 = Box::new(MockProvider::new(vec![MockResponse::Text(
"first answer".to_string(),
)]));
let (mut session1, _cancel1) = make_session(&env, provider1).await;
let sink1 = TestSink::new();
let (_, mut cmd_rx1) = mpsc::channel::<EngineCommand>(1);
session1
.run_turn(&env.config, None, &sink1, &mut cmd_rx1)
.await
.expect("turn 1 should succeed");
assert!(
sink1.events().iter().any(|e| matches!(
e,
EngineEvent::TurnEnd {
reason: TurnEndReason::Complete,
..
}
)),
"turn 1 should end with Complete"
);
env.insert_user_message("second question").await;
let provider2 = Box::new(MockProvider::new(vec![MockResponse::Text(
"second answer".to_string(),
)]));
let (mut session2, _cancel2) = make_session(&env, provider2).await;
let sink2 = TestSink::new();
let (_, mut cmd_rx2) = mpsc::channel::<EngineCommand>(1);
session2
.run_turn(&env.config, None, &sink2, &mut cmd_rx2)
.await
.expect("turn 2 should succeed");
let messages: Vec<koda_core::persistence::Message> =
env.db.load_context(&env.session_id).await.unwrap();
let contents: Vec<String> = messages
.iter()
.filter_map(|m: &koda_core::persistence::Message| m.content.clone())
.collect();
assert!(
contents
.iter()
.any(|c: &String| c.contains("first question")),
"DB should contain first user message"
);
assert!(
contents.iter().any(|c: &String| c.contains("first answer")),
"DB should contain first assistant response"
);
assert!(
contents
.iter()
.any(|c: &String| c.contains("second question")),
"DB should contain second user message"
);
assert!(
contents
.iter()
.any(|c: &String| c.contains("second answer")),
"DB should contain second assistant response"
);
}