use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use crate::agent::runner::AgentResponse;
use crate::agent::Agent;
use crate::checkpoint::{CheckpointState, ErasedCheckpoint};
use crate::error::{DaimonError, Result};
use crate::hooks::AgentState;
use crate::model::types::{ChatRequest, Message, Usage};
impl Agent {
#[tracing::instrument(skip_all, fields(run_id = %run_id, input_len = input.len()))]
pub async fn prompt_resumable(
&self,
input: &str,
run_id: &str,
checkpoint: &Arc<dyn ErasedCheckpoint>,
) -> Result<AgentResponse> {
self.prompt_resumable_with_cancellation(
input,
run_id,
checkpoint,
&CancellationToken::new(),
)
.await
}
#[tracing::instrument(skip_all, fields(run_id = %run_id, input_len = input.len()))]
pub async fn prompt_resumable_with_cancellation(
&self,
input: &str,
run_id: &str,
checkpoint: &Arc<dyn ErasedCheckpoint>,
cancel: &CancellationToken,
) -> Result<AgentResponse> {
let existing = checkpoint.load_erased(run_id).await?;
let (mut messages, start_iteration, mut total_usage) = if let Some(cp) = existing {
if cp.completed {
tracing::info!(run_id, "checkpoint already completed, replaying result");
let final_text = cp
.messages
.last()
.and_then(|m| m.content.clone())
.unwrap_or_default();
return Ok(AgentResponse {
messages: cp.messages,
final_text,
iterations: cp.iteration,
usage: Usage::default(),
cost: 0.0,
});
}
tracing::info!(
run_id,
iteration = cp.iteration,
messages = cp.messages.len(),
"resuming from checkpoint"
);
(cp.messages, cp.iteration, Usage::default())
} else {
let history = self.memory.get_messages_erased().await?;
let mut msgs = Vec::new();
if let Some(system) = &self.system_prompt {
msgs.push(Message::system(system));
}
msgs.extend(history);
msgs.push(Message::user(input));
self.memory.add_message_erased(Message::user(input)).await?;
(msgs, 0, Usage::default())
};
let mut tool_specs_vec: Vec<crate::model::types::ToolSpec> =
self.tools.tool_specs().to_vec();
let mut iteration = start_iteration;
loop {
if cancel.is_cancelled() {
checkpoint
.save_erased(&CheckpointState::new(
run_id,
messages.clone(),
iteration,
))
.await?;
return Err(DaimonError::Cancelled);
}
iteration += 1;
let state = AgentState {
iteration,
max_iterations: self.max_iterations,
};
self.hooks.on_iteration_start_erased(&state).await?;
let mut request = ChatRequest {
messages: std::mem::take(&mut messages),
tools: std::mem::take(&mut tool_specs_vec),
temperature: self.temperature,
max_tokens: self.max_tokens,
};
let result = {
tracing::debug!(iteration, "calling model (resumable)");
self.model
.generate_erased(&request)
.instrument(tracing::info_span!("model_generate", iteration))
.await
};
messages = std::mem::take(&mut request.messages);
tool_specs_vec = std::mem::take(&mut request.tools);
let response = result?;
if let Some(ref usage) = response.usage {
total_usage.accumulate(usage);
}
self.hooks.on_model_response_erased(&response).await?;
if response.has_tool_calls() {
let tool_calls = response.tool_calls().to_vec();
let assistant_msg =
Message::assistant_with_tool_calls(tool_calls.clone());
messages.push(assistant_msg.clone());
self.memory.add_message_erased(assistant_msg).await?;
let tool_results = self.execute_tools_parallel(&tool_calls).await;
for (call, tool_result) in tool_calls.iter().zip(tool_results) {
let result_msg =
Message::tool_result(&call.id, &tool_result.content);
messages.push(result_msg.clone());
self.memory.add_message_erased(result_msg).await?;
}
checkpoint
.save_erased(&CheckpointState::new(
run_id,
messages.clone(),
iteration,
))
.await?;
self.hooks.on_iteration_end_erased(&state).await?;
if iteration >= self.max_iterations {
return Err(DaimonError::MaxIterations(self.max_iterations));
}
continue;
}
let final_text = response.text().to_string();
messages.push(response.message.clone());
self.memory.add_message_erased(response.message).await?;
let completed_state = CheckpointState::new(
run_id,
messages.clone(),
iteration,
)
.mark_completed();
checkpoint.save_erased(&completed_state).await?;
self.hooks.on_iteration_end_erased(&state).await?;
return Ok(AgentResponse {
messages,
final_text,
iterations: iteration,
usage: total_usage,
cost: 0.0,
});
}
}
#[tracing::instrument(skip_all, fields(run_id = %run_id))]
pub async fn replay(
&self,
run_id: &str,
checkpoint: &Arc<dyn ErasedCheckpoint>,
from_iteration: Option<usize>,
) -> Result<AgentResponse> {
let state = checkpoint
.load_erased(run_id)
.await?
.ok_or_else(|| DaimonError::Other(format!("no checkpoint for run '{run_id}'")))?;
let messages = if let Some(target) = from_iteration {
truncate_to_iteration(state.messages, target)
} else {
state.messages
};
tracing::info!(
run_id,
message_count = messages.len(),
from_iteration,
"replaying from checkpoint"
);
self.run_react_loop(messages, &CancellationToken::new())
.await
}
}
fn truncate_to_iteration(messages: Vec<Message>, target_iteration: usize) -> Vec<Message> {
let mut result = Vec::new();
let mut assistant_count = 0;
for msg in messages {
result.push(msg.clone());
if msg.role == crate::model::types::Role::Assistant {
assistant_count += 1;
if assistant_count >= target_iteration {
break;
}
}
}
result
}