use tracing::{debug, error, warn};
use saorsa_ai::{
CompletionRequest, ContentBlock, ContentDelta, Message, StopReason, StreamEvent,
StreamingProvider,
};
use crate::config::AgentConfig;
use crate::error::{Result, SaorsaAgentError};
use crate::event::{AgentEvent, EventSender, TurnEndReason};
use crate::tool::ToolRegistry;
/// Drives a multi-turn conversation between a streaming provider and a set of tools.
///
/// The loop owns the conversation history and reports progress through `event_tx`.
pub struct AgentLoop {
/// Streaming provider asked for a completion at the start of every turn.
provider: Box<dyn StreamingProvider>,
/// Settings read by `run`: model, system prompt, `max_tokens`, and `max_turns`.
config: AgentConfig,
/// Tools advertised in each request and executed when the model asks for them.
tools: ToolRegistry,
/// Channel on which `AgentEvent`s are emitted; `run` ignores send failures.
event_tx: EventSender,
/// Conversation history; it is appended to by `run` and never cleared here.
messages: Vec<Message>,
}
impl AgentLoop {
pub fn new(
provider: Box<dyn StreamingProvider>,
config: AgentConfig,
tools: ToolRegistry,
event_tx: EventSender,
) -> Self {
Self {
provider,
config,
tools,
event_tx,
messages: Vec::new(),
}
}
pub async fn run(&mut self, user_message: &str) -> Result<String> {
self.messages.push(Message::user(user_message));
let mut turn = 0u32;
let mut final_text = String::new();
loop {
turn += 1;
if turn > self.config.max_turns {
debug!(turn, max = self.config.max_turns, "Max turns reached");
let _ = self
.event_tx
.send(AgentEvent::TurnEnd {
turn,
reason: TurnEndReason::MaxTurns,
})
.await;
break;
}
let _ = self.event_tx.send(AgentEvent::TurnStart { turn }).await;
let request = CompletionRequest::new(
&self.config.model,
self.messages.clone(),
self.config.max_tokens,
)
.system(&self.config.system_prompt)
.tools(self.tools.definitions());
let mut rx = self.provider.stream(request).await?;
let mut text_content = String::new();
let mut tool_calls: Vec<ToolCallInfo> = Vec::new();
let mut stop_reason = None;
while let Some(event) = rx.recv().await {
match event {
Ok(StreamEvent::ContentBlockStart {
content_block: ContentBlock::ToolUse { id, name, .. },
..
}) => {
tool_calls.push(ToolCallInfo {
id,
name,
input_json: String::new(),
});
}
Ok(StreamEvent::ContentBlockDelta {
delta: ContentDelta::TextDelta { text },
..
}) => {
text_content.push_str(&text);
let _ = self.event_tx.send(AgentEvent::TextDelta { text }).await;
}
Ok(StreamEvent::ContentBlockDelta {
delta: ContentDelta::InputJsonDelta { partial_json },
..
}) => {
if let Some(tc) = tool_calls.last_mut() {
tc.input_json.push_str(&partial_json);
}
}
Ok(StreamEvent::ContentBlockDelta {
delta: ContentDelta::ThinkingDelta { text },
..
}) => {
let _ = self.event_tx.send(AgentEvent::ThinkingDelta { text }).await;
}
Ok(StreamEvent::MessageDelta {
stop_reason: sr, ..
}) => {
stop_reason = sr;
}
Ok(StreamEvent::Error { message }) => {
error!(message = %message, "Stream error");
let _ = self
.event_tx
.send(AgentEvent::Error {
message: message.clone(),
})
.await;
return Err(SaorsaAgentError::Internal(message));
}
_ => {}
}
}
if !text_content.is_empty() {
final_text.clone_from(&text_content);
let _ = self
.event_tx
.send(AgentEvent::TextComplete {
text: text_content.clone(),
})
.await;
}
let mut assistant_content: Vec<ContentBlock> = Vec::new();
if !text_content.is_empty() {
assistant_content.push(ContentBlock::Text { text: text_content });
}
let mut parsed_inputs = Vec::with_capacity(tool_calls.len());
for tc in &tool_calls {
let input: serde_json::Value =
serde_json::from_str(&tc.input_json).unwrap_or_else(|e| {
warn!(
tool = %tc.name,
error = %e,
"Malformed tool call JSON, using empty object"
);
serde_json::Value::Object(serde_json::Map::new())
});
let _ = self
.event_tx
.send(AgentEvent::ToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
input: input.clone(),
})
.await;
assistant_content.push(ContentBlock::ToolUse {
id: tc.id.clone(),
name: tc.name.clone(),
input: input.clone(),
});
parsed_inputs.push(input);
}
self.messages.push(Message {
role: saorsa_ai::Role::Assistant,
content: assistant_content,
});
match stop_reason {
Some(StopReason::ToolUse) if !tool_calls.is_empty() => {
let tool_results = self.execute_tool_calls(&tool_calls, &parsed_inputs).await;
for result in &tool_results {
self.messages
.push(Message::tool_result(&result.id, &result.output));
}
let _ = self
.event_tx
.send(AgentEvent::TurnEnd {
turn,
reason: TurnEndReason::ToolUse,
})
.await;
}
Some(StopReason::MaxTokens) => {
let _ = self
.event_tx
.send(AgentEvent::TurnEnd {
turn,
reason: TurnEndReason::MaxTokens,
})
.await;
break;
}
_ => {
let _ = self
.event_tx
.send(AgentEvent::TurnEnd {
turn,
reason: TurnEndReason::EndTurn,
})
.await;
break;
}
}
}
Ok(final_text)
}
/// Executes the requested tool calls one after another and collects their outputs.
///
/// `tool_calls` and `inputs` are paired by position. A call that names an unregistered
/// tool, or whose tool returns an error, produces a failure message as its output
/// instead of aborting the batch. One `ToolResult` event is emitted per call.
async fn execute_tool_calls(
    &self,
    tool_calls: &[ToolCallInfo],
    inputs: &[serde_json::Value],
) -> Vec<ToolResultInfo> {
    let mut collected = Vec::new();
    for (call, args) in tool_calls.iter().zip(inputs) {
        // `Ok` holds the tool's output; `Err` holds the failure text reported instead.
        let outcome = match self.tools.get(&call.name) {
            None => Err(format!("Unknown tool: {}", call.name)),
            Some(tool) => tool
                .execute(args.clone())
                .await
                .map_err(|e| format!("Error: {e}")),
        };
        let success = outcome.is_ok();
        let output = match outcome {
            Ok(text) | Err(text) => text,
        };

        let event = AgentEvent::ToolResult {
            id: call.id.clone(),
            name: call.name.clone(),
            output: output.clone(),
            success,
        };
        let _ = self.event_tx.send(event).await;

        collected.push(ToolResultInfo {
            id: call.id.clone(),
            output,
        });
    }
    collected
}
/// Returns the conversation history accumulated so far.
pub fn messages(&self) -> &[Message] {
    self.messages.as_slice()
}
}
/// A tool invocation assembled from stream events during one turn.
#[derive(Debug)]
struct ToolCallInfo {
/// Identifier taken from the `ToolUse` content block that started the call.
id: String,
/// Tool name taken from the same content block; used to look the tool up.
name: String,
/// Raw argument JSON, concatenated from the `InputJsonDelta` fragments.
input_json: String,
}
/// Outcome of one executed tool call, ready to be appended to the history.
#[derive(Debug)]
struct ToolResultInfo {
/// Identifier of the tool call this result answers.
id: String,
/// Tool output, or a failure message when the tool was unknown or returned an error.
output: String,
}
/// Builds a [`ToolRegistry`] holding the built-in tool set.
///
/// The bash, read, write, edit, grep, find, and ls tools are each constructed with
/// `working_dir`; the web search tool is constructed without it.
pub fn default_tools(working_dir: impl Into<std::path::PathBuf>) -> ToolRegistry {
    use crate::tools::{
        BashTool, EditTool, FindTool, GrepTool, LsTool, ReadTool, WebSearchTool, WriteTool,
    };

    let root: std::path::PathBuf = working_dir.into();
    let mut tools = ToolRegistry::new();

    // Tools that receive the working directory.
    tools.register(Box::new(BashTool::new(root.clone())));
    tools.register(Box::new(ReadTool::new(root.clone())));
    tools.register(Box::new(WriteTool::new(root.clone())));
    tools.register(Box::new(EditTool::new(root.clone())));
    tools.register(Box::new(GrepTool::new(root.clone())));
    tools.register(Box::new(FindTool::new(root.clone())));
    // Last user of the path, so it can take ownership.
    tools.register(Box::new(LsTool::new(root)));

    // Takes no working directory.
    tools.register(Box::new(WebSearchTool::new()));

    tools
}
#[cfg(test)]
mod tests {
use super::*;
use crate::event::event_channel;
/// Test provider that replays a fixed list of stream events.
struct MockProvider {
/// Events sent, in order, on every call to `stream`.
events: Vec<StreamEvent>,
}
#[async_trait::async_trait]
impl saorsa_ai::Provider for MockProvider {
    /// Non-streaming completion is not supported by the mock; it always fails.
    async fn complete(
        &self,
        _request: CompletionRequest,
    ) -> saorsa_ai::Result<saorsa_ai::CompletionResponse> {
        let unsupported = saorsa_ai::SaorsaAiError::Internal("not implemented".into());
        Err(unsupported)
    }
}
#[async_trait::async_trait]
impl StreamingProvider for MockProvider {
    /// Replays the canned events on a fresh channel from a spawned task.
    async fn stream(
        &self,
        _request: CompletionRequest,
    ) -> saorsa_ai::Result<tokio::sync::mpsc::Receiver<saorsa_ai::Result<StreamEvent>>>
    {
        let (sender, receiver) = tokio::sync::mpsc::channel(64);
        let script = self.events.clone();
        tokio::spawn(async move {
            for item in script {
                // Stop replaying once the receiving side has gone away.
                let delivered = sender.send(Ok(item)).await.is_ok();
                if !delivered {
                    break;
                }
            }
        });
        Ok(receiver)
    }
}
/// Builds a mock provider whose stream is one text block containing `text`,
/// followed by an `EndTurn` stop.
fn mock_text_provider(text: &str) -> Box<dyn StreamingProvider> {
    let opening = StreamEvent::MessageStart {
        id: "msg_1".into(),
        model: "test".into(),
        usage: saorsa_ai::Usage::default(),
    };
    let block_start = StreamEvent::ContentBlockStart {
        index: 0,
        content_block: ContentBlock::Text {
            text: String::new(),
        },
    };
    let block_delta = StreamEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::TextDelta {
            text: text.to_string(),
        },
    };
    let block_stop = StreamEvent::ContentBlockStop { index: 0 };
    let closing = StreamEvent::MessageDelta {
        stop_reason: Some(StopReason::EndTurn),
        usage: saorsa_ai::Usage::default(),
    };

    let events = vec![
        opening,
        block_start,
        block_delta,
        block_stop,
        closing,
        StreamEvent::MessageStop,
    ];
    Box::new(MockProvider { events })
}
#[tokio::test]
async fn agent_simple_text_response() {
let provider = mock_text_provider("Hello, world!");
let config = AgentConfig::default();
let tools = ToolRegistry::new();
let (tx, mut rx) = event_channel(64);
let mut agent = AgentLoop::new(provider, config, tools, tx);
let handle = tokio::spawn(async move { agent.run("Hi").await });
let mut events = Vec::new();
while let Some(event) = rx.recv().await {
events.push(event);
}
let result = handle.await;
assert!(result.is_ok());
if let Ok(Ok(text)) = result {
assert_eq!(text, "Hello, world!");
}
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::TurnStart { turn: 1 }))
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::TextDelta { .. }))
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::TextComplete { .. }))
);
assert!(events.iter().any(|e| matches!(
e,
AgentEvent::TurnEnd {
reason: TurnEndReason::EndTurn,
..
}
)));
}
#[tokio::test]
async fn agent_max_turns_limit() {
let provider = mock_text_provider("response");
let config = AgentConfig::default().max_turns(0);
let tools = ToolRegistry::new();
let (tx, _rx) = event_channel(64);
let mut agent = AgentLoop::new(provider, config, tools, tx);
let result = agent.run("Hi").await;
assert!(result.is_ok());
}
/// After one text turn the history holds the user message and the assistant reply.
#[tokio::test]
async fn agent_tracks_messages() {
    // Keep the receiver alive so event sends have somewhere to go.
    let (tx, _rx) = event_channel(64);
    let mut agent = AgentLoop::new(
        mock_text_provider("response"),
        AgentConfig::default(),
        ToolRegistry::new(),
        tx,
    );
    let _ = agent.run("Hello").await;
    let history = agent.messages();
    assert_eq!(history.len(), 2);
}
/// The default registry contains exactly the eight built-in tools.
#[test]
fn default_tools_registers_all() {
    let cwd = std::env::current_dir();
    assert!(cwd.is_ok());
    let Ok(dir) = cwd else { unreachable!() };

    let registry = super::default_tools(dir);
    assert_eq!(registry.len(), 8);

    let names = registry.names();
    let expected_names = [
        "bash",
        "read",
        "write",
        "edit",
        "grep",
        "find",
        "ls",
        "web_search",
    ];
    for expected in expected_names {
        assert!(names.contains(&expected));
    }
}
}