use super::config::*;
use super::helpers::default_convert_to_llm;
use crate::provider::{ProviderError, ProviderRegistry, StreamConfig, StreamEvent, StreamProvider};
use crate::types::*;
use chrono::Utc;
use tokio::sync::mpsc;
use tracing::warn;
/// Builds one [`BlockProvenance`] entry per `AgentMessage::Llm` found in `messages`,
/// in input order. Agent messages of any other kind contribute nothing.
///
/// Resolution rules, applied per message:
/// - If the message carries a `provenance_hint`, that hint is cloned into the output
///   as-is and no per-turn index is consumed.
/// - If the message has a `turn_id`, the entry is `BlockProvenance::LoopTurn`. Its role
///   follows the message kind (an assistant message containing at least one
///   `Content::ToolCall` is a `ToolCallRequest`), and `message_index` counts the
///   non-hinted messages already seen for the same `turn_index`, starting at 0.
/// - Without a `turn_id`, a user message is `Steering` unless the output already holds
///   a `Steering` entry (hinted ones included), in which case it is `FollowUp`.
///   Every other message kind is `Unknown`.
pub(super) fn derive_provenance(messages: &[AgentMessage]) -> Vec<BlockProvenance> {
    use std::collections::HashMap;

    // Next unused `message_index` for each turn index encountered so far.
    let mut next_index_by_turn: HashMap<u32, usize> = HashMap::new();
    let mut provenance: Vec<BlockProvenance> = Vec::new();

    for agent_message in messages {
        let llm = match agent_message {
            AgentMessage::Llm(inner) => inner,
            _ => continue,
        };

        // An explicit hint takes precedence over anything derived below.
        if let Some(hint) = &llm.provenance_hint {
            provenance.push((**hint).clone());
            continue;
        }

        let entry = if let Some(turn) = &llm.turn_id {
            let role = match &llm.message {
                Message::User { .. } => ProvenanceRole::UserMessage,
                Message::Assistant { content, .. } => {
                    let requests_tool = content
                        .iter()
                        .any(|part| matches!(part, Content::ToolCall { .. }));
                    if requests_tool {
                        ProvenanceRole::ToolCallRequest
                    } else {
                        ProvenanceRole::AssistantResponse
                    }
                }
                Message::ToolResult { .. } => ProvenanceRole::ToolCallResult,
            };

            // Hand out the current slot for this turn, then advance it.
            let slot = next_index_by_turn.entry(turn.turn_index).or_insert(0);
            let message_index = *slot;
            *slot += 1;

            BlockProvenance::LoopTurn {
                turn_index: turn.turn_index as usize,
                role,
                message_index,
            }
        } else {
            match &llm.message {
                Message::User { .. } => {
                    let steering_seen = provenance
                        .iter()
                        .any(|earlier| matches!(earlier, BlockProvenance::Steering));
                    if steering_seen {
                        BlockProvenance::FollowUp
                    } else {
                        BlockProvenance::Steering
                    }
                }
                _ => BlockProvenance::Unknown,
            }
        };

        provenance.push(entry);
    }

    provenance
}
pub(super) async fn stream_assistant_response(
context: &AgentContext, config: &AgentLoopConfig, tx: &mpsc::UnboundedSender<AgentEvent>, cancel: &tokio_util::sync::CancellationToken, loop_id: &str,
turn_index: u32, ) -> Message {
let base_messages = if context.active_node_id.is_some() {
context.build_trunk_context_with_policy(&config.revert_render_policy, turn_index)
} else {
context.build_working_context()
};
let messages = if let Some(transform) = &config.transform_context {
transform(base_messages)
} else {
base_messages
};
let convert = config.convert_to_llm.as_ref();
let llm_messages = match convert {
Some(f) => f(&messages),
None => default_convert_to_llm(&messages), };
let tool_defs: Vec<crate::provider::ToolDefinition> = context
.tools
.iter()
.map(|t| crate::provider::ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
})
.collect();
{
let provenance = derive_provenance(&messages);
let payload = AnnotatedRequestPayload {
system_prompt: context.system_prompt.clone(),
messages: llm_messages.clone(),
provenance,
tools: tool_defs.clone(),
model_id: config.model_config.id.clone(),
thinking_level: config.thinking_level,
max_tokens: config.max_tokens,
temperature: config.temperature,
response_format: config.response_format.clone(),
};
tx.send(AgentEvent::TurnRequest {
loop_id: loop_id.to_string(),
turn_index,
payload,
timestamp: Utc::now(),
})
.ok();
}
let registry = ProviderRegistry::default();
let provider: &dyn StreamProvider = match config.provider_override.as_deref() {
Some(p) => p,
None => match registry.get(&config.model_config.api) {
Some(p) => p,
None => {
return Message::Assistant {
content: vec![Content::Text {
text: String::new(),
}],
stop_reason: StopReason::Error,
model: config.model_config.id.clone(),
provider: String::new(),
usage: Usage::default(),
timestamp: now_ms(),
error_message: Some(format!(
"No provider registered for protocol: {}",
config.model_config.api
)),
};
}
},
};
let retry = &config.retry_config;
let mut attempt = 0;
let mut auth_refreshed = false;
let (result, mut stream_rx) = loop {
let stream_config = StreamConfig {
model_config: config.model_config.clone(),
system_prompt: context.system_prompt.clone(),
messages: llm_messages.clone(),
tools: tool_defs.clone(),
thinking_level: config.thinking_level,
max_tokens: config.max_tokens,
temperature: config.temperature,
cache_config: config.cache_config.clone(),
response_format: config.response_format.clone(),
};
let (stream_tx, stream_rx) = mpsc::unbounded_channel();
let provider_cancel = cancel.clone();
let result = provider
.stream(stream_config, stream_tx, provider_cancel)
.await;
match &result {
Err(e) if e.is_retryable() && attempt < retry.max_retries && !cancel.is_cancelled() => {
attempt += 1;
let delay = e
.retry_after()
.unwrap_or_else(|| retry.delay_for_attempt(attempt));
crate::provider::retry::log_retry(attempt, retry.max_retries, &delay, e);
tokio::time::sleep(delay).await;
continue; }
Err(ProviderError::Auth(_))
if config.model_config.credentials.is_some()
&& !auth_refreshed
&& !cancel.is_cancelled() =>
{
auth_refreshed = true;
tracing::warn!(
"Provider returned Auth error; refreshing credentials and retrying once."
);
if let Err(e) = config.model_config.invalidate_credentials().await {
tracing::warn!("CredentialProvider::invalidate failed: {}", e);
}
continue;
}
_ => break (result, stream_rx), }
};
let mut partial_message: Option<AgentMessage> = None;
while let Ok(event) = stream_rx.try_recv() {
match &event {
StreamEvent::Start => {
let placeholder = AgentMessage::Llm(LlmMessage::new(Message::Assistant {
content: Vec::new(),
stop_reason: StopReason::Stop,
model: config.model_config.id.clone(),
provider: String::new(),
usage: Usage::default(),
timestamp: now_ms(),
error_message: None,
}));
partial_message = Some(placeholder.clone());
tx.send(AgentEvent::MessageStart {
loop_id: loop_id.to_string(),
message: placeholder,
})
.ok(); }
StreamEvent::TextDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
tx.send(AgentEvent::MessageUpdate {
loop_id: loop_id.to_string(),
message: msg.clone(),
delta: StreamDelta::Text {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::ThinkingDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
tx.send(AgentEvent::MessageUpdate {
loop_id: loop_id.to_string(),
message: msg.clone(),
delta: StreamDelta::Thinking {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::ToolCallDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
tx.send(AgentEvent::MessageUpdate {
loop_id: loop_id.to_string(),
message: msg.clone(),
delta: StreamDelta::ToolCallDelta {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::Done { message } => {
let am: AgentMessage = message.clone().into();
partial_message = Some(am.clone());
tx.send(AgentEvent::MessageEnd {
loop_id: loop_id.to_string(),
message: am,
})
.ok();
}
StreamEvent::Error { message } => {
let am: AgentMessage = message.clone().into();
if partial_message.is_none() {
tx.send(AgentEvent::MessageStart {
loop_id: loop_id.to_string(),
message: am.clone(),
})
.ok();
}
partial_message = Some(am.clone());
tx.send(AgentEvent::MessageEnd {
loop_id: loop_id.to_string(),
message: am,
})
.ok();
}
_ => {} }
}
match result {
Ok(msg) => msg,
Err(e) => {
warn!("Provider error: {}", e);
Message::Assistant {
content: vec![Content::Text {
text: String::new(), }],
stop_reason: StopReason::Error,
model: config.model_config.id.clone(),
provider: "unknown".into(), usage: Usage::default(),
timestamp: now_ms(),
error_message: Some(e.to_string()), }
}
}
}