use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tracing::{info, warn, error};
use crate::channel::{Channel, ChannelKey, ChannelType};
use crate::delivery::{DeliveryStrategy, MessageRef, select_strategy};
use super::acp_streaming::{
CodingAgentUpdate, CodingAgentUpdateStream, StreamingAcpClient,
format_update_for_display,
};
use super::models::{TaskError, TaskRequest, TaskResult};
use super::registry::CodingAgentRegistry;
/// Tuning knobs controlling how coding-agent progress is streamed to a chat channel.
#[derive(Clone, Debug)]
pub struct StreamingConfig {
    /// Minimum time that must elapse between two status-block partial updates.
    pub update_interval: Duration,
    /// Whether agent "thought" updates are surfaced to the user.
    pub show_thoughts: bool,
    /// Whether generated text is streamed as partial updates before completion.
    pub stream_text: bool,
    /// Text-length threshold, in bytes, used to decide when accumulated text is
    /// streamed as a partial update.
    pub text_chunk_size: usize,
}

impl Default for StreamingConfig {
    /// Half-second status throttle, thoughts hidden, text streaming enabled with a
    /// 100-byte chunk threshold.
    fn default() -> Self {
        const DEFAULT_UPDATE_INTERVAL_MS: u64 = 500;
        const DEFAULT_TEXT_CHUNK_SIZE: usize = 100;

        StreamingConfig {
            update_interval: Duration::from_millis(DEFAULT_UPDATE_INTERVAL_MS),
            show_thoughts: false,
            stream_text: true,
            text_chunk_size: DEFAULT_TEXT_CHUNK_SIZE,
        }
    }
}
/// Executes coding-agent tasks over a streaming ACP client and relays their
/// progress to chat channels.
pub struct StreamingTaskExecutor {
    /// ACP client used to start tasks and receive their update streams.
    client: StreamingAcpClient,
    /// Registry used to resolve an agent id to the agent (and its config).
    registry: Arc<CodingAgentRegistry>,
    /// Channels available for delivery, keyed by channel type and account id.
    channel_map: Arc<DashMap<ChannelKey, Arc<dyn Channel>>>,
    /// Throttling and visibility settings for streamed updates.
    config: StreamingConfig,
}
impl StreamingTaskExecutor {
    /// Creates an executor that uses [`StreamingConfig::default`].
    pub fn new(
        registry: Arc<CodingAgentRegistry>,
        channel_map: Arc<DashMap<ChannelKey, Arc<dyn Channel>>>,
    ) -> Self {
        Self::with_config(registry, channel_map, StreamingConfig::default())
    }

    /// Creates an executor with an explicit streaming configuration.
    pub fn with_config(
        registry: Arc<CodingAgentRegistry>,
        channel_map: Arc<DashMap<ChannelKey, Arc<dyn Channel>>>,
        config: StreamingConfig,
    ) -> Self {
        Self {
            client: StreamingAcpClient::new(),
            registry,
            channel_map,
            config,
        }
    }

    /// Runs `request` on the agent `agent_id`, streaming progress to the chat
    /// target named in `request.reply_to`.
    ///
    /// The target is derived with [`msg_ref_from_request`] and looked up in the
    /// executor's channel map; the delivery strategy is chosen with
    /// `select_strategy(.., Some("partial"))`. When no channel is registered for
    /// that target the task still runs, only without live updates.
    ///
    /// # Errors
    ///
    /// * [`TaskError::AgentDisconnected`] if `agent_id` is not in the registry.
    /// * Any error from starting the ACP stream.
    /// * [`TaskError::ExecutionError`] if the agent reports failure or the
    ///   stream ends without a `Done` update.
    pub async fn execute_with_streaming(
        &self,
        agent_id: &str,
        request: &TaskRequest,
    ) -> Result<TaskResult, TaskError> {
        let agent = self.registry.get_agent(agent_id).ok_or_else(|| {
            TaskError::AgentDisconnected {
                agent_id: agent_id.to_string(),
            }
        })?;

        let msg_ref = msg_ref_from_request(request);
        let channel_key = ChannelKey {
            channel_type: msg_ref.channel_type,
            account_id: msg_ref.account_id.clone(),
        };
        // The map guard `c` is dropped when the closure returns, so it is not
        // held across the `.await`s below.
        let delivery = self
            .channel_map
            .get(&channel_key)
            .map(|c| select_strategy(c.value().clone(), Some("partial")));

        if delivery.is_none() {
            warn!(
                channel_type = ?msg_ref.channel_type,
                account_id = %msg_ref.account_id,
                recipient = %msg_ref.recipient_id,
                "No delivery channel found — task will execute without live streaming"
            );
        }

        info!(
            agent_id = %agent_id,
            channel_type = ?msg_ref.channel_type,
            recipient = %msg_ref.recipient_id,
            has_delivery = delivery.is_some(),
            "Starting streaming task execution"
        );

        let mut stream = self
            .client
            .execute_streaming(agent_id, &agent.config, request)
            .await?;

        match delivery {
            Some(delivery) => self.process_stream(&mut stream, delivery, msg_ref).await,
            None => self.process_stream_no_delivery(&mut stream).await,
        }
    }

    /// Runs `request` on the agent `agent_id`, streaming progress through a
    /// caller-supplied `delivery` strategy to `msg_ref`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::execute_with_streaming`].
    pub async fn execute_with_delivery(
        &self,
        agent_id: &str,
        request: &TaskRequest,
        delivery: Arc<dyn DeliveryStrategy>,
        msg_ref: MessageRef,
    ) -> Result<TaskResult, TaskError> {
        let agent = self.registry.get_agent(agent_id).ok_or_else(|| {
            TaskError::AgentDisconnected {
                agent_id: agent_id.to_string(),
            }
        })?;

        let mut stream = self
            .client
            .execute_streaming(agent_id, &agent.config, request)
            .await?;

        self.process_stream(&mut stream, delivery, msg_ref).await
    }

    /// Drains `stream`, relaying progress through `delivery`, until the agent
    /// reports `Done`.
    ///
    /// * Text is accumulated. When `stream_text` is enabled, a partial holding
    ///   all text so far is sent each time at least `text_chunk_size` bytes of
    ///   new text have arrived since text was last delivered.
    /// * Thoughts (only when `show_thoughts`), statuses, tool calls and
    ///   permission decisions are queued as one-line status entries. They are
    ///   flushed, together with the accumulated text, at most once per
    ///   `update_interval`. A flush is only attempted when an update arrives,
    ///   so entries still queued when `Done` arrives are not sent separately.
    /// * On `Done`, the final output (or a failure message) is sent through
    ///   `on_complete`.
    ///
    /// Delivery results are discarded, so a delivery problem never fails the
    /// task itself.
    async fn process_stream(
        &self,
        stream: &mut CodingAgentUpdateStream,
        delivery: Arc<dyn DeliveryStrategy>,
        msg_ref: MessageRef,
    ) -> Result<TaskResult, TaskError> {
        let mut accumulated_text = String::new();
        // Length of `accumulated_text` at the last delivery that included it.
        // Invariant: never exceeds `accumulated_text.len()`, since the text only grows.
        let mut text_sent_len: usize = 0;
        let mut last_update_sent = std::time::Instant::now();
        let mut pending_status_updates: Vec<String> = Vec::new();

        while let Some(update) = stream.recv().await {
            match update {
                CodingAgentUpdate::Text(text) => {
                    accumulated_text.push_str(&text);
                    // Only re-send once a full chunk of *new* text is available.
                    // Comparing the total length (as before) re-sent on every
                    // chunk after the first threshold was crossed.
                    if self.config.stream_text
                        && accumulated_text.len() - text_sent_len >= self.config.text_chunk_size
                    {
                        let _ = delivery.on_partial(&accumulated_text, &msg_ref).await;
                        text_sent_len = accumulated_text.len();
                    }
                }
                CodingAgentUpdate::Thought(thought) => {
                    if self.config.show_thoughts {
                        let thought_text = format!("💭 _{}_", thought.trim());
                        pending_status_updates.push(thought_text);
                    }
                }
                CodingAgentUpdate::Status(status) => {
                    if let Some(formatted) =
                        format_update_for_display(&CodingAgentUpdate::Status(status))
                    {
                        pending_status_updates.push(formatted);
                    }
                }
                CodingAgentUpdate::ToolCallStarted { title } => {
                    pending_status_updates.push(format!("🔧 {}", title));
                }
                CodingAgentUpdate::ToolCallCompleted { title } => {
                    pending_status_updates.push(format!("✅ {}", title));
                }
                CodingAgentUpdate::PermissionRequested { title, approved } => {
                    let formatted = if approved {
                        format!("🔓 {}", title)
                    } else {
                        format!("🔒 Denied: {}", title)
                    };
                    pending_status_updates.push(formatted);
                }
                CodingAgentUpdate::Done { output, duration, success, error } => {
                    if success {
                        let _ = delivery.on_complete(&output, &msg_ref).await;
                        return Ok(TaskResult {
                            output,
                            modified_files: vec![],
                            // Saturate instead of silently truncating the u128 millis.
                            duration_ms: u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
                            token_usage: None,
                        });
                    }
                    let error_msg = error.unwrap_or_else(|| "Unknown error".to_string());
                    let _ = delivery
                        .on_complete(&format!("❌ Task failed: {}", error_msg), &msg_ref)
                        .await;
                    return Err(TaskError::ExecutionError {
                        message: error_msg,
                        partial_output: if output.is_empty() { None } else { Some(output) },
                    });
                }
            }

            if !pending_status_updates.is_empty()
                && last_update_sent.elapsed() >= self.config.update_interval
            {
                let status_block = pending_status_updates.join("\n");
                let combined = if accumulated_text.is_empty() {
                    status_block
                } else {
                    format!("{}\n\n---\n{}", status_block, accumulated_text)
                };
                let _ = delivery.on_partial(&combined, &msg_ref).await;
                pending_status_updates.clear();
                last_update_sent = std::time::Instant::now();
                // The combined message already carried all text received so far.
                text_sent_len = accumulated_text.len();
            }
        }

        error!("ACP stream ended unexpectedly without Done update");
        Err(TaskError::ExecutionError {
            message: "Stream ended unexpectedly".to_string(),
            partial_output: if accumulated_text.is_empty() { None } else { Some(accumulated_text) },
        })
    }

    /// Drains `stream` without any live delivery, keeping only text, and
    /// returns the outcome reported by the `Done` update.
    ///
    /// All non-text, non-`Done` updates are ignored.
    async fn process_stream_no_delivery(
        &self,
        stream: &mut CodingAgentUpdateStream,
    ) -> Result<TaskResult, TaskError> {
        let mut accumulated_text = String::new();

        while let Some(update) = stream.recv().await {
            match update {
                CodingAgentUpdate::Text(text) => {
                    accumulated_text.push_str(&text);
                }
                CodingAgentUpdate::Done { output, duration, success, error } => {
                    if success {
                        return Ok(TaskResult {
                            output,
                            modified_files: vec![],
                            // Saturate instead of silently truncating the u128 millis.
                            duration_ms: u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
                            token_usage: None,
                        });
                    }
                    let error_msg = error.unwrap_or_else(|| "Unknown error".to_string());
                    return Err(TaskError::ExecutionError {
                        message: error_msg,
                        partial_output: if output.is_empty() { None } else { Some(output) },
                    });
                }
                _ => {}
            }
        }

        error!("ACP stream ended unexpectedly without Done update (no-delivery path)");
        Err(TaskError::ExecutionError {
            message: "Stream ended unexpectedly".to_string(),
            partial_output: if accumulated_text.is_empty() { None } else { Some(accumulated_text) },
        })
    }

    /// Returns the usage statistics collected by the underlying ACP client.
    pub fn usage_stats(&self) -> adk_acp::AcpUsageStats {
        self.client.usage_stats()
    }
}
pub fn msg_ref_from_request(request: &TaskRequest) -> MessageRef {
let channel_type = match request.reply_to.channel_type.to_lowercase().as_str() {
"telegram" => ChannelType::Telegram,
"slack" => ChannelType::Slack,
"discord" => ChannelType::Discord,
"whatsapp" => ChannelType::Whatsapp,
"matrix" => ChannelType::Matrix,
_ => ChannelType::Telegram, };
MessageRef {
channel_type,
account_id: "default".to_string(),
recipient_id: request.reply_to.channel_id.clone(),
message_id: request.reply_to.message_id.clone(),
reply_to: request.reply_to.message_id.clone(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coding_agent::models::{ReplyTarget, TaskTrigger};
    use std::path::PathBuf;

    /// Builds a minimal control-panel request that replies to Telegram chat
    /// `"123"`, message `"456"`.
    fn make_request(description: &str) -> TaskRequest {
        TaskRequest {
            description: description.to_string(),
            trigger: TaskTrigger::ControlPanel {
                user_id: "test".to_string(),
            },
            workspace: Some(PathBuf::from("/tmp/test")),
            file_context: None,
            reply_to: ReplyTarget {
                channel_type: "telegram".to_string(),
                channel_id: "123".to_string(),
                message_id: Some("456".to_string()),
            },
        }
    }

    #[test]
    fn test_streaming_config_defaults() {
        let config = StreamingConfig::default();
        assert_eq!(config.update_interval, Duration::from_millis(500));
        assert!(!config.show_thoughts);
        assert!(config.stream_text);
        assert_eq!(config.text_chunk_size, 100);
    }

    #[test]
    fn test_msg_ref_from_request() {
        let request = make_request("test task");
        let msg_ref = msg_ref_from_request(&request);
        assert_eq!(msg_ref.channel_type, ChannelType::Telegram);
        assert_eq!(msg_ref.account_id, "default");
        assert_eq!(msg_ref.recipient_id, "123");
        assert_eq!(msg_ref.message_id, Some("456".to_string()));
        assert_eq!(msg_ref.reply_to, Some("456".to_string()));
    }

    #[test]
    fn test_msg_ref_channel_type_mapping() {
        let mut request = make_request("test");

        request.reply_to.channel_type = "slack".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Slack);

        request.reply_to.channel_type = "discord".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Discord);

        request.reply_to.channel_type = "whatsapp".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Whatsapp);

        request.reply_to.channel_type = "matrix".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Matrix);

        // Matching is case-insensitive.
        request.reply_to.channel_type = "TELEGRAM".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Telegram);

        request.reply_to.channel_type = "WhatsApp".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Whatsapp);

        // Unknown channel types fall back to Telegram.
        request.reply_to.channel_type = "unknown".to_string();
        assert_eq!(msg_ref_from_request(&request).channel_type, ChannelType::Telegram);
    }
}