use super::manager::SessionUpdate;
use super::types::{FeedItem, SessionStatus};
use tokio::sync::mpsc;
pub async fn pump_session(
mut rx: mpsc::UnboundedReceiver<SessionUpdate>,
session_key: String,
on_feed: impl Fn(FeedItem),
on_status: impl Fn(SessionStatus),
on_session_id: impl Fn(String),
on_output_file: impl Fn(String),
on_pid: impl Fn(u32),
) {
let mut got_terminal = false;
let mut captured_session_id: Option<String> = None;
while let Some(update) = rx.recv().await {
match update {
SessionUpdate::Feed(item) => {
if let FeedItem::ToolCall { ref name, ref input } = item {
if let Some(path) = extract_output_file(name, input) {
on_output_file(path);
}
}
if let FeedItem::FinalResult { .. } = &item {
got_terminal = true;
}
on_feed(item);
}
SessionUpdate::Status(status) => {
let is_terminal = matches!(
status,
SessionStatus::Completed | SessionStatus::Error(_)
);
on_status(status);
if is_terminal {
got_terminal = true;
break;
}
}
SessionUpdate::SessionId(sid) => {
captured_session_id = Some(sid.clone());
on_session_id(sid);
}
SessionUpdate::ProcessPid(pid) => {
on_pid(pid);
}
}
}
if !got_terminal {
let sid_info = captured_session_id.as_deref().unwrap_or("unknown");
eprintln!(
"[keel:pump] channel closed without terminal status (key={}, session={})",
session_key, sid_info
);
on_status(SessionStatus::Error(
"Session ended unexpectedly".to_string(),
));
}
}
fn extract_output_file(name: &str, input: &serde_json::Value) -> Option<String> {
match name {
"Write" | "Edit" | "MultiEdit" => input
.get("file_path")
.or_else(|| input.get("path"))
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
#[tokio::test]
async fn pump_delivers_feed_items() {
let (tx, rx) = mpsc::unbounded_channel();
let items: Arc<Mutex<Vec<FeedItem>>> = Arc::new(Mutex::new(Vec::new()));
let items2 = Arc::clone(&items);
tx.send(SessionUpdate::Feed(FeedItem::AssistantText(
"hello".into(),
)))
.unwrap();
tx.send(SessionUpdate::Status(SessionStatus::Completed))
.unwrap();
drop(tx);
pump_session(
rx,
"test".into(),
move |item| items2.lock().unwrap().push(item),
|_| {},
|_| {},
|_| {},
|_| {},
)
.await;
let captured = items.lock().unwrap();
assert_eq!(captured.len(), 1);
assert!(matches!(&captured[0], FeedItem::AssistantText(t) if t == "hello"));
}
#[tokio::test]
async fn pump_calls_on_session_id() {
let (tx, rx) = mpsc::unbounded_channel();
let sid: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let sid2 = Arc::clone(&sid);
tx.send(SessionUpdate::SessionId("s-123".into())).unwrap();
tx.send(SessionUpdate::Status(SessionStatus::Completed))
.unwrap();
drop(tx);
pump_session(
rx,
"test".into(),
|_| {},
|_| {},
move |s| *sid2.lock().unwrap() = Some(s),
|_| {},
|_| {},
)
.await;
assert_eq!(*sid.lock().unwrap(), Some("s-123".to_string()));
}
#[tokio::test]
async fn pump_detects_output_files() {
let (tx, rx) = mpsc::unbounded_channel();
let files: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let files2 = Arc::clone(&files);
tx.send(SessionUpdate::Feed(FeedItem::ToolCall {
name: "Write".into(),
input: serde_json::json!({"file_path": "/tmp/out.txt"}),
}))
.unwrap();
tx.send(SessionUpdate::Status(SessionStatus::Completed))
.unwrap();
drop(tx);
pump_session(
rx,
"test".into(),
|_| {},
|_| {},
|_| {},
move |f| files2.lock().unwrap().push(f),
|_| {},
)
.await;
let captured = files.lock().unwrap();
assert_eq!(captured.len(), 1);
assert_eq!(captured[0], "/tmp/out.txt");
}
#[tokio::test]
async fn pump_synthesises_error_on_unexpected_close() {
let (tx, rx) = mpsc::unbounded_channel();
let statuses: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let statuses2 = Arc::clone(&statuses);
drop(tx);
pump_session(
rx,
"test".into(),
|_| {},
move |s| {
let label = match s {
SessionStatus::Error(e) => format!("error: {e}"),
other => format!("{other:?}"),
};
statuses2.lock().unwrap().push(label);
},
|_| {},
|_| {},
|_| {},
)
.await;
let captured = statuses.lock().unwrap();
assert_eq!(captured.len(), 1);
assert!(captured[0].contains("unexpectedly"));
}
#[test]
fn extract_output_file_from_write() {
let input = serde_json::json!({"file_path": "/tmp/out.txt"});
assert_eq!(
extract_output_file("Write", &input),
Some("/tmp/out.txt".to_string())
);
}
#[test]
fn extract_output_file_from_edit() {
let input = serde_json::json!({"file_path": "/tmp/edit.rs"});
assert_eq!(
extract_output_file("Edit", &input),
Some("/tmp/edit.rs".to_string())
);
}
#[test]
fn extract_output_file_from_other_tool() {
let input = serde_json::json!({"path": "/tmp/foo"});
assert_eq!(extract_output_file("Read", &input), None);
}
}