use std::io;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::controller::{
ControllerEvent, ControllerInputPayload, Executable, LLMController, LLMSessionConfig, LLMTool,
ListSkillsTool, PermissionRegistry, ToolRegistry, UserInteractionRegistry,
};
use crate::skills::{SkillDiscovery, SkillDiscoveryError, SkillRegistry, SkillReloadResult};
use super::config::{AgentConfig, LLMRegistry, load_config};
use super::error::AgentError;
use super::logger::Logger;
use super::messages::UiMessage;
use super::messages::channels::DEFAULT_CHANNEL_SIZE;
use super::router::InputRouter;
/// Sending half of the bounded channel that carries frontend input to the controller.
pub type ToControllerTx = mpsc::Sender<ControllerInputPayload>;
/// Receiving half of the frontend → controller channel; `AgentAir` hands it to the `InputRouter`.
pub type ToControllerRx = mpsc::Receiver<ControllerInputPayload>;
/// Sending half of the bounded channel that carries [`UiMessage`]s towards the frontend.
pub type FromControllerTx = mpsc::Sender<UiMessage>;
/// Receiving half of the controller → frontend [`UiMessage`] channel.
pub type FromControllerRx = mpsc::Receiver<UiMessage>;
/// Agent host: owns the Tokio runtime, the [`LLMController`], the channels between
/// the frontend and the controller, and the tool/skill state applied to new sessions.
pub struct AgentAir {
    // NOTE(review): never read after construction; presumably held so the logger
    // stays alive for the agent's lifetime — confirm against `Logger` and say so here.
    #[allow(dead_code)]
    logger: Logger,
    /// Agent name taken from the config; used in log messages.
    name: String,
    /// Version string; `new` initialises it to "0.1.0", `set_version` overrides it.
    version: String,
    /// Dedicated Tokio runtime; the synchronous methods block on it and spawn tasks onto it.
    runtime: Runtime,
    /// Shared LLM controller that owns the sessions.
    controller: Arc<LLMController>,
    /// Loaded LLM provider registry; `None` once taken via `take_llm_registry`.
    llm_registry: Option<LLMRegistry>,
    /// Sender for frontend input destined for the controller.
    to_controller_tx: ToControllerTx,
    /// Receiver for frontend input; `Some` until `start_background_tasks` hands it to the router.
    to_controller_rx: Option<ToControllerRx>,
    /// Sender for UI messages destined for the frontend.
    from_controller_tx: FromControllerTx,
    /// Receiver for UI messages; `Some` until taken by `take_from_controller_rx` or `run_with_frontend`.
    from_controller_rx: Option<FromControllerRx>,
    /// Cancelled by `shutdown`; a clone is given to the `InputRouter`.
    cancel_token: CancellationToken,
    /// Registry used by tools that need to ask the user something.
    user_interaction_registry: Arc<UserInteractionRegistry>,
    /// Registry used by tools that need permission before acting.
    permission_registry: Arc<PermissionRegistry>,
    /// Tool definitions attached to every session this agent creates.
    tool_definitions: Vec<LLMTool>,
    /// Optional custom message for the "no session" case; only stored and exposed here.
    error_no_session: Option<String>,
    /// Registered skills; their prompt XML is appended to new sessions' system prompts.
    skill_registry: Arc<SkillRegistry>,
    /// Discovery configuration used by `load_skills` and `reload_skills`.
    skill_discovery: SkillDiscovery,
}
impl AgentAir {
    /// Builds a new agent from `config`.
    ///
    /// The constructor does the following, in order:
    /// - initialises the logger;
    /// - loads the LLM provider registry via [`load_config`];
    /// - creates a dedicated Tokio [`Runtime`];
    /// - creates the controller channels, each sized by `config.channel_buffer_size()`
    ///   (falling back to `DEFAULT_CHANNEL_SIZE`);
    /// - spawns two tasks that convert user-interaction and permission
    ///   [`ControllerEvent`]s into [`UiMessage`]s on the from-controller channel.
    ///
    /// A missing LLM provider configuration is not an error; it is only logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns an error if the logger cannot be created or the Tokio runtime fails to start.
    pub fn new<C: AgentConfig>(config: &C) -> io::Result<Self> {
        let logger = Logger::new(config.log_prefix())?;
        tracing::info!("{} agent initialized", config.name());

        let llm_registry = load_config(config);
        if llm_registry.is_empty() {
            tracing::warn!(
                "No LLM providers configured. Set ANTHROPIC_API_KEY or create ~/{}",
                config.config_path()
            );
        } else {
            tracing::info!(
                "Loaded {} LLM provider(s): {:?}",
                llm_registry.providers().len(),
                llm_registry.providers()
            );
        }

        let runtime = Runtime::new()
            .map_err(|e| io::Error::other(format!("Failed to create runtime: {}", e)))?;

        let channel_size = config.channel_buffer_size().unwrap_or(DEFAULT_CHANNEL_SIZE);
        tracing::debug!("Using channel buffer size: {}", channel_size);

        let (to_controller_tx, to_controller_rx) =
            mpsc::channel::<ControllerInputPayload>(channel_size);
        let (from_controller_tx, from_controller_rx) = mpsc::channel::<UiMessage>(channel_size);

        // The interaction and permission registries emit ControllerEvents on their own
        // channels; forward both streams into the UI message channel.
        let (interaction_event_tx, interaction_event_rx) =
            mpsc::channel::<ControllerEvent>(channel_size);
        let user_interaction_registry =
            Arc::new(UserInteractionRegistry::new(interaction_event_tx));
        Self::spawn_event_forwarder(
            &runtime,
            interaction_event_rx,
            from_controller_tx.clone(),
            "user interaction",
        );

        let (permission_event_tx, permission_event_rx) =
            mpsc::channel::<ControllerEvent>(channel_size);
        let permission_registry = Arc::new(PermissionRegistry::new(permission_event_tx));
        Self::spawn_event_forwarder(
            &runtime,
            permission_event_rx,
            from_controller_tx.clone(),
            "permission",
        );

        let controller = Arc::new(LLMController::new(
            permission_registry.clone(),
            Some(from_controller_tx.clone()),
            Some(channel_size),
        ));
        let cancel_token = CancellationToken::new();

        Ok(Self {
            logger,
            name: config.name().to_string(),
            version: "0.1.0".to_string(),
            runtime,
            controller,
            llm_registry: Some(llm_registry),
            to_controller_tx,
            to_controller_rx: Some(to_controller_rx),
            from_controller_tx,
            from_controller_rx: Some(from_controller_rx),
            cancel_token,
            user_interaction_registry,
            permission_registry,
            tool_definitions: Vec::new(),
            error_no_session: None,
            skill_registry: Arc::new(SkillRegistry::new()),
            skill_discovery: SkillDiscovery::new(),
        })
    }

    /// Spawns a task on `runtime` that converts each event received on `events` into a
    /// [`UiMessage`] and forwards it to `ui_tx`, until `events` is closed.
    ///
    /// `kind` only labels the warning logged when a message cannot be delivered.
    fn spawn_event_forwarder(
        runtime: &Runtime,
        mut events: mpsc::Receiver<ControllerEvent>,
        ui_tx: FromControllerTx,
        kind: &'static str,
    ) {
        runtime.spawn(async move {
            while let Some(event) = events.recv().await {
                let msg = convert_controller_event_to_ui_message(event);
                if let Err(e) = ui_tx.send(msg).await {
                    tracing::warn!("Failed to send {} event to UI: {}", kind, e);
                }
            }
        });
    }

    /// Convenience constructor: builds a `SimpleConfig` from `name`, `config_path`
    /// and `system_prompt`, then delegates to [`Self::new`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::new`].
    pub fn with_config(
        name: impl Into<String>,
        config_path: impl Into<String>,
        system_prompt: impl Into<String>,
    ) -> io::Result<Self> {
        let config = super::config::SimpleConfig::new(name, config_path, system_prompt);
        Self::new(&config)
    }

    /// Stores a custom message for the "no session" case.
    ///
    /// The agent only stores the message and exposes it through [`Self::error_no_session`].
    pub fn set_error_no_session(&mut self, message: impl Into<String>) -> &mut Self {
        self.error_no_session = Some(message.into());
        self
    }

    /// Returns the custom "no session" message, if one was set.
    pub fn error_no_session(&self) -> Option<&str> {
        self.error_no_session.as_deref()
    }

    /// Overrides the version string (initialised to "0.1.0" by [`Self::new`]).
    pub fn set_version(&mut self, version: impl Into<String>) {
        self.version = version.into();
    }

    /// Returns the version string.
    pub fn version(&self) -> &str {
        &self.version
    }

    /// Replaces the LLM registry with `registry.with_environment_context()`.
    ///
    /// This is a no-op if the registry has already been taken via [`Self::take_llm_registry`].
    pub fn load_environment_context(&mut self) -> &mut Self {
        if let Some(registry) = self.llm_registry.take() {
            self.llm_registry = Some(registry.with_environment_context());
            tracing::info!("Environment context loaded into system prompt");
        }
        self
    }

    /// Lets `f` register tool implementations and returns their LLM definitions.
    ///
    /// `f` receives the controller's tool registry plus the interaction and
    /// permission registries. The returned definitions **replace** any previously
    /// stored definitions, including one appended by [`Self::register_list_skills_tool`].
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolRegistration`] with `f`'s error message.
    pub fn register_tools<F>(&mut self, f: F) -> Result<(), AgentError>
    where
        F: FnOnce(
            &Arc<ToolRegistry>,
            &Arc<UserInteractionRegistry>,
            &Arc<PermissionRegistry>,
        ) -> Result<Vec<LLMTool>, String>,
    {
        let tool_defs = f(
            self.controller.tool_registry(),
            &self.user_interaction_registry,
            &self.permission_registry,
        )
        .map_err(AgentError::ToolRegistration)?;
        self.tool_definitions = tool_defs;
        Ok(())
    }

    /// Async variant of [`Self::register_tools`].
    ///
    /// `f`'s future is driven to completion on the agent's runtime. As with the
    /// synchronous version, the returned definitions replace the stored ones.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolRegistration`] with the future's error message.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn register_tools_async<F, Fut>(&mut self, f: F) -> Result<(), AgentError>
    where
        F: FnOnce(Arc<ToolRegistry>, Arc<UserInteractionRegistry>, Arc<PermissionRegistry>) -> Fut,
        Fut: std::future::Future<Output = Result<Vec<LLMTool>, String>>,
    {
        let tool_defs = self
            .runtime
            .block_on(f(
                self.controller.tool_registry().clone(),
                self.user_interaction_registry.clone(),
                self.permission_registry.clone(),
            ))
            .map_err(AgentError::ToolRegistration)?;
        self.tool_definitions = tool_defs;
        Ok(())
    }

    /// Spawns the controller's main task and, the first time only, the [`InputRouter`].
    ///
    /// The router consumes the to-controller receiver, so later calls skip it.
    /// Each call does spawn another `controller.start()`.
    pub fn start_background_tasks(&mut self) {
        tracing::info!("{} starting background tasks", self.name);
        let controller = self.controller.clone();
        self.runtime.spawn(async move {
            controller.start().await;
        });
        tracing::info!("Controller started");
        if let Some(to_controller_rx) = self.to_controller_rx.take() {
            let router = InputRouter::new(
                self.controller.clone(),
                to_controller_rx,
                self.cancel_token.clone(),
            );
            self.runtime.spawn(async move {
                router.run().await;
            });
            tracing::info!("InputRouter started");
        }
    }

    /// Creates a session from `config` and attaches `tools` to it.
    ///
    /// If the skill registry renders non-empty prompt XML, that XML is appended to
    /// the system prompt, separated by a blank line; it becomes the whole prompt
    /// when there is none. Tools are attached only when `tools` is non-empty and
    /// the new session can be looked up.
    async fn create_session_internal(
        controller: &Arc<LLMController>,
        mut config: LLMSessionConfig,
        tools: &[LLMTool],
        skill_registry: &Arc<SkillRegistry>,
    ) -> Result<i64, crate::client::error::LlmError> {
        let skills_xml = skill_registry.to_prompt_xml();
        if !skills_xml.is_empty() {
            config.system_prompt = Some(match config.system_prompt {
                Some(prompt) => format!("{}\n\n{}", prompt, skills_xml),
                None => skills_xml,
            });
        }
        let id = controller.create_session(config).await?;
        if !tools.is_empty()
            && let Some(session) = controller.get_session(id).await
        {
            session.set_tools(tools.to_vec()).await;
        }
        Ok(id)
    }

    /// Creates a session using the registry's default provider configuration.
    ///
    /// Returns `(session_id, model, context_limit)`.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::NoConfiguration`] in either of these cases:
    /// - the registry was taken;
    /// - the registry has no default provider.
    ///
    /// Session-creation failures from the controller are also returned.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn create_initial_session(&mut self) -> Result<(i64, String, i32), AgentError> {
        let registry = self
            .llm_registry
            .as_ref()
            .ok_or_else(|| AgentError::NoConfiguration("No LLM registry available".to_string()))?;
        let config = registry.get_default().ok_or_else(|| {
            AgentError::NoConfiguration("No default LLM provider configured".to_string())
        })?;
        let model = config.model.clone();
        let context_limit = config.context_limit;

        // Shared borrows of disjoint fields suffice; no need to clone the tool list.
        let session_id = self.runtime.block_on(Self::create_session_internal(
            &self.controller,
            config.clone(),
            &self.tool_definitions,
            &self.skill_registry,
        ))?;
        tracing::info!(
            session_id = session_id,
            model = %model,
            "Created initial session"
        );
        Ok((session_id, model, context_limit))
    }

    /// Creates a session from an explicit `config`, with the agent's tools and skills applied.
    ///
    /// # Errors
    ///
    /// Propagates session-creation failures from the controller.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn create_session(&self, config: LLMSessionConfig) -> Result<i64, AgentError> {
        self.runtime
            .block_on(Self::create_session_internal(
                &self.controller,
                config,
                &self.tool_definitions,
                &self.skill_registry,
            ))
            .map_err(AgentError::from)
    }

    /// Cancels the shared cancellation token, then waits for the controller to shut down.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn shutdown(&self) {
        tracing::info!("{} shutting down", self.name);
        self.cancel_token.cancel();
        self.runtime.block_on(async {
            self.controller.shutdown().await;
        });
        tracing::info!("{} shutdown complete", self.name);
    }

    /// Runs the agent with a caller-supplied frontend until `input_source` is exhausted.
    ///
    /// Steps, in order:
    /// 1. Starts the controller and the input router via [`Self::start_background_tasks`].
    /// 2. If the from-controller receiver is still available, spawns a task that
    ///    forwards every [`UiMessage`] to `event_sink`, except in two cases:
    ///    - permission requests (single or batch) that `permission_policy` resolves
    ///      without asking the user are answered directly and not forwarded;
    ///    - user-interaction requests are cancelled, not forwarded, when the policy
    ///      does not support interaction.
    /// 3. Tries to create the initial session; failure is logged, not returned.
    /// 4. Forwards every input to the controller, then calls [`Self::shutdown`].
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn run_with_frontend<E, I, P>(
        &mut self,
        event_sink: E,
        mut input_source: I,
        permission_policy: P,
    ) -> io::Result<()>
    where
        E: super::interface::EventSink,
        I: super::interface::InputSource,
        P: super::interface::PermissionPolicy,
    {
        use super::interface::PolicyDecision;
        use crate::permissions::{BatchPermissionResponse, PermissionPanelResponse};

        tracing::info!("{} starting with custom frontend", self.name);

        // Start the controller *and* the InputRouter. The router is the only consumer of
        // `to_controller_rx`; without it the inputs forwarded below would never reach the
        // controller, and `send` would block forever once the channel filled up.
        self.start_background_tasks();

        if let Some(mut from_controller_rx) = self.from_controller_rx.take() {
            let permission_registry = self.permission_registry.clone();
            let user_interaction_registry = self.user_interaction_registry.clone();
            self.runtime.spawn(async move {
                while let Some(event) = from_controller_rx.recv().await {
                    match &event {
                        UiMessage::PermissionRequired {
                            tool_use_id,
                            request,
                            ..
                        } => {
                            // `None` means the policy defers to the user, so the request
                            // falls through and is forwarded to the sink below.
                            let auto_response = match permission_policy.decide(request) {
                                PolicyDecision::AskUser => None,
                                PolicyDecision::Allow => Some(PermissionPanelResponse {
                                    granted: true,
                                    grant: None,
                                    message: None,
                                }),
                                PolicyDecision::AllowWithGrant(grant) => {
                                    Some(PermissionPanelResponse {
                                        granted: true,
                                        grant: Some(grant),
                                        message: None,
                                    })
                                }
                                PolicyDecision::Deny { reason } => Some(PermissionPanelResponse {
                                    granted: false,
                                    grant: None,
                                    message: reason,
                                }),
                            };
                            if let Some(response) = auto_response {
                                if let Err(e) = permission_registry
                                    .respond_to_request(tool_use_id, response)
                                    .await
                                {
                                    tracing::warn!(
                                        "Failed to respond to permission request: {}",
                                        e
                                    );
                                }
                                continue;
                            }
                        }
                        UiMessage::BatchPermissionRequired { batch, .. } => {
                            let mut all_handled = true;
                            let mut approved_grants = Vec::new();
                            let mut denied_ids = Vec::new();
                            for request in &batch.requests {
                                match permission_policy.decide(request) {
                                    PolicyDecision::Allow => {}
                                    PolicyDecision::AllowWithGrant(grant) => {
                                        approved_grants.push(grant);
                                    }
                                    PolicyDecision::Deny { .. } => {
                                        denied_ids.push(request.id.clone());
                                    }
                                    PolicyDecision::AskUser => {
                                        // One undecided request sends the whole batch to the user.
                                        all_handled = false;
                                        break;
                                    }
                                }
                            }
                            if all_handled {
                                // NOTE(review): a batch mixing allowed and denied requests is
                                // answered with `all_denied(...)`; confirm whether the allowed
                                // ones are meant to be denied too.
                                let response = if denied_ids.is_empty() {
                                    BatchPermissionResponse::all_granted(
                                        &batch.batch_id,
                                        approved_grants,
                                    )
                                } else {
                                    BatchPermissionResponse::all_denied(&batch.batch_id, denied_ids)
                                };
                                if let Err(e) = permission_registry
                                    .respond_to_batch(&batch.batch_id, response)
                                    .await
                                {
                                    tracing::warn!(
                                        "Failed to respond to batch permission request: {}",
                                        e
                                    );
                                }
                                continue;
                            }
                        }
                        UiMessage::UserInteractionRequired { tool_use_id, .. } => {
                            if !permission_policy.supports_interaction() {
                                if let Err(e) = user_interaction_registry.cancel(tool_use_id).await
                                {
                                    tracing::warn!("Failed to cancel user interaction: {}", e);
                                }
                                tracing::debug!("Auto-cancelled user interaction in headless mode");
                                continue;
                            }
                        }
                        _ => {}
                    }
                    if let Err(e) = event_sink.send(event) {
                        tracing::warn!("Failed to send event to sink: {}", e);
                    }
                }
            });
        }

        match self.create_initial_session() {
            Ok((session_id, model, _)) => {
                tracing::info!(session_id, model = %model, "Created initial session");
            }
            Err(e) => {
                tracing::warn!(error = %e, "No initial session created");
            }
        }

        let to_controller_tx = self.to_controller_tx.clone();
        self.runtime.block_on(async {
            while let Some(input) = input_source.recv().await {
                if let Err(e) = to_controller_tx.send(input).await {
                    tracing::error!(error = %e, "Failed to send input to controller");
                    break;
                }
            }
        });

        self.shutdown();
        tracing::info!("{} stopped", self.name);
        Ok(())
    }

    /// Returns a clone of the sender used to feed input to the controller.
    pub fn to_controller_tx(&self) -> ToControllerTx {
        self.to_controller_tx.clone()
    }

    /// Takes the UI message receiver; returns `None` if it was already taken.
    pub fn take_from_controller_rx(&mut self) -> Option<FromControllerRx> {
        self.from_controller_rx.take()
    }

    /// Returns the shared controller.
    pub fn controller(&self) -> &Arc<LLMController> {
        &self.controller
    }

    /// Returns the agent's Tokio runtime.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

    /// Returns a cloned handle to the agent's Tokio runtime.
    pub fn runtime_handle(&self) -> tokio::runtime::Handle {
        self.runtime.handle().clone()
    }

    /// Returns the user-interaction registry.
    pub fn user_interaction_registry(&self) -> &Arc<UserInteractionRegistry> {
        &self.user_interaction_registry
    }

    /// Returns the permission registry.
    pub fn permission_registry(&self) -> &Arc<PermissionRegistry> {
        &self.permission_registry
    }

    /// Removes a session from the controller and cleans up its per-session state.
    ///
    /// It always cancels the session's pending permission and interaction requests
    /// and cleans up its tool-registry state, even if the controller did not know
    /// the session. Returns the controller's removal result.
    pub async fn remove_session(&self, session_id: i64) -> bool {
        let removed = self.controller.remove_session(session_id).await;
        self.permission_registry.cancel_session(session_id).await;
        self.user_interaction_registry
            .cancel_session(session_id)
            .await;
        self.controller
            .tool_registry()
            .cleanup_session(session_id)
            .await;
        if removed {
            tracing::info!(session_id, "Session removed with full cleanup");
        }
        removed
    }

    /// Returns the LLM registry, or `None` if it was taken.
    pub fn llm_registry(&self) -> Option<&LLMRegistry> {
        self.llm_registry.as_ref()
    }

    /// Takes the LLM registry out of the agent.
    ///
    /// Afterwards `create_initial_session` fails with `NoConfiguration`.
    pub fn take_llm_registry(&mut self) -> Option<LLMRegistry> {
        self.llm_registry.take()
    }

    /// Returns a clone of the cancellation token that `shutdown` cancels.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Returns the agent name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns a clone of the sender for UI messages.
    pub fn from_controller_tx(&self) -> FromControllerTx {
        self.from_controller_tx.clone()
    }

    /// Returns the tool definitions attached to new sessions.
    pub fn tool_definitions(&self) -> &[LLMTool] {
        &self.tool_definitions
    }

    /// Returns the shared skill registry.
    pub fn skill_registry(&self) -> &Arc<SkillRegistry> {
        &self.skill_registry
    }

    /// Registers a `list_skills` tool backed by the shared skill registry.
    ///
    /// The tool's definition is appended to the stored definitions and also returned.
    /// A later [`Self::register_tools`] call replaces the stored definitions.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolRegistration`] if the tool registry rejects the tool.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (Tokio `block_on`).
    pub fn register_list_skills_tool(&mut self) -> Result<LLMTool, AgentError> {
        let tool = ListSkillsTool::new(self.skill_registry.clone());
        let llm_tool = tool.to_llm_tool();
        self.runtime
            .block_on(self.controller.tool_registry().register(Arc::new(tool)))
            .map_err(|e| AgentError::ToolRegistration(e.to_string()))?;
        self.tool_definitions.push(llm_tool.clone());
        tracing::info!("Registered list_skills tool");
        Ok(llm_tool)
    }

    /// Adds a directory to the skill discovery paths used by `load_skills` and `reload_skills`.
    pub fn add_skill_path(&mut self, path: std::path::PathBuf) -> &mut Self {
        self.skill_discovery.add_path(path);
        self
    }

    /// Discovers skills on the configured paths and registers them.
    ///
    /// Returns the number of skills loaded and the per-skill discovery errors.
    pub fn load_skills(&mut self) -> (usize, Vec<SkillDiscoveryError>) {
        let results = self.skill_discovery.discover();
        self.register_discovered_skills(results)
    }

    /// Discovers skills only on `paths`, using a fresh `SkillDiscovery::empty()`, and registers them.
    ///
    /// These paths are not added to the agent's discovery configuration, so a later
    /// [`Self::reload_skills`] unregisters skills loaded this way.
    pub fn load_skills_from(
        &self,
        paths: Vec<std::path::PathBuf>,
    ) -> (usize, Vec<SkillDiscoveryError>) {
        let mut discovery = SkillDiscovery::empty();
        for path in paths {
            discovery.add_path(path);
        }
        let results = discovery.discover();
        self.register_discovered_skills(results)
    }

    /// Registers every successfully discovered skill and collects the failures.
    ///
    /// A skill whose name is already registered replaces the existing one, with a warning.
    fn register_discovered_skills(
        &self,
        results: Vec<Result<crate::skills::Skill, SkillDiscoveryError>>,
    ) -> (usize, Vec<SkillDiscoveryError>) {
        let mut errors = Vec::new();
        let mut count = 0;
        for result in results {
            match result {
                Ok(skill) => {
                    let skill_name = skill.metadata.name.clone();
                    let skill_path = skill.path.clone();
                    if let Some(old_skill) = self.skill_registry.register(skill) {
                        tracing::warn!(
                            skill_name = %skill_name,
                            new_path = %skill_path.display(),
                            old_path = %old_skill.path.display(),
                            "Duplicate skill name detected - replaced existing skill"
                        );
                    }
                    tracing::info!(
                        skill_name = %skill_name,
                        skill_path = %skill_path.display(),
                        "Loaded skill"
                    );
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!(
                        path = %e.path.display(),
                        error = %e.message,
                        "Failed to load skill"
                    );
                    errors.push(e);
                }
            }
        }
        tracing::info!("Loaded {} skill(s)", count);
        (count, errors)
    }

    /// Re-runs discovery on the configured paths and syncs the skill registry.
    ///
    /// The registry is updated as follows:
    /// - names not registered before are reported in `added`;
    /// - every discovered skill is (re-)registered;
    /// - previously registered names that were not rediscovered are unregistered and
    ///   reported in `removed`.
    ///
    /// A skill that fails to load is therefore removed. So is any skill loaded
    /// through [`Self::load_skills_from`].
    pub fn reload_skills(&mut self) -> SkillReloadResult {
        let current_names: std::collections::HashSet<String> =
            self.skill_registry.names().into_iter().collect();
        let results = self.skill_discovery.discover();
        let mut discovered_names = std::collections::HashSet::new();
        let mut result = SkillReloadResult::default();

        for discovery_result in results {
            match discovery_result {
                Ok(skill) => {
                    let name = skill.metadata.name.clone();
                    // `insert` returns false for a name already seen in this pass, so a new
                    // skill discovered under two paths is reported as added only once.
                    let first_sighting = discovered_names.insert(name.clone());
                    if first_sighting && !current_names.contains(&name) {
                        tracing::info!(skill_name = %name, "Added new skill");
                        result.added.push(name);
                    }
                    self.skill_registry.register(skill);
                }
                Err(e) => {
                    tracing::warn!(
                        path = %e.path.display(),
                        error = %e.message,
                        "Failed to load skill during reload"
                    );
                    result.errors.push(e);
                }
            }
        }

        for name in &current_names {
            if !discovered_names.contains(name) {
                tracing::info!(skill_name = %name, "Removed skill");
                self.skill_registry.unregister(name);
                result.removed.push(name.clone());
            }
        }

        tracing::info!(
            added = result.added.len(),
            removed = result.removed.len(),
            errors = result.errors.len(),
            "Skills reloaded"
        );
        result
    }

    /// Renders the registered skills as prompt XML (empty when there are no skills).
    pub fn skills_prompt_xml(&self) -> String {
        self.skill_registry.to_prompt_xml()
    }

    /// Rewrites a live session's system prompt so it lists the current skills.
    ///
    /// How the prompt is rewritten:
    /// - an existing `<available_skills>` section is replaced;
    /// - otherwise the XML is appended after a blank line;
    /// - an empty prompt becomes just the XML.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SessionNotFound`] when skills exist but the session does not.
    pub async fn refresh_session_skills(&self, session_id: i64) -> Result<(), AgentError> {
        let skills_xml = self.skills_prompt_xml();
        if skills_xml.is_empty() {
            // NOTE(review): an existing `<available_skills>` section is left untouched here,
            // so a session keeps advertising skills after all of them were removed
            // (e.g. by `reload_skills`). Confirm whether it should be stripped instead.
            return Ok(());
        }
        let session = self
            .controller
            .get_session(session_id)
            .await
            .ok_or(AgentError::SessionNotFound(session_id))?;
        let current_prompt = session.system_prompt().await.unwrap_or_default();
        let new_prompt = if current_prompt.contains("<available_skills>") {
            replace_skills_section(&current_prompt, &skills_xml)
        } else if current_prompt.is_empty() {
            skills_xml
        } else {
            format!("{}\n\n{}", current_prompt, skills_xml)
        };
        session.set_system_prompt(new_prompt).await;
        tracing::debug!(session_id, "Refreshed session skills");
        Ok(())
    }
}
/// Replaces the first `<available_skills>…</available_skills>` section of `prompt`
/// with `new_skills_xml`, keeping the text before and after it.
///
/// The closing tag is searched for only *after* the opening tag. A stray closing
/// tag earlier in the prompt therefore cannot produce an inverted range, which
/// previously duplicated part of the prompt.
///
/// If there is no opening tag, or no closing tag follows it, the section is treated
/// as absent and `new_skills_xml` is appended after a blank line.
fn replace_skills_section(prompt: &str, new_skills_xml: &str) -> String {
    const OPEN_TAG: &str = "<available_skills>";
    const CLOSE_TAG: &str = "</available_skills>";

    // Byte range [start, end) of the whole section, including both tags.
    let section = prompt.find(OPEN_TAG).and_then(|start| {
        prompt[start..]
            .find(CLOSE_TAG)
            .map(|offset| (start, start + offset + CLOSE_TAG.len()))
    });

    match section {
        Some((start, end)) => {
            let mut result = String::with_capacity(prompt.len() + new_skills_xml.len());
            result.push_str(&prompt[..start]);
            result.push_str(new_skills_xml);
            result.push_str(&prompt[end..]);
            result
        }
        None => format!("{}\n\n{}", prompt, new_skills_xml),
    }
}
/// Translates a [`ControllerEvent`] into the [`UiMessage`] variant the frontend consumes.
///
/// Mapping details worth knowing:
/// - `StreamStart` becomes a `System` message with an empty text.
/// - `ToolUseStart` becomes a `Display` line of the form `"Executing tool: <name>"`.
/// - `ToolUse` falls back to the tool's own name when no display name is given,
///   and to an empty title when no display title is given.
/// - Token counts on `TextChunk` and `Complete` are always reported as `0`.
/// - `TokenUpdate` carries no turn id.
///
/// The remaining variants are copied field-for-field.
pub fn convert_controller_event_to_ui_message(event: ControllerEvent) -> UiMessage {
    match event {
        ControllerEvent::StreamStart { session_id, .. } => {
            let message = String::new();
            UiMessage::System { session_id, message }
        }
        ControllerEvent::TextChunk {
            session_id,
            turn_id,
            text,
        } => UiMessage::TextChunk {
            text,
            session_id,
            turn_id,
            output_tokens: 0,
            input_tokens: 0,
        },
        ControllerEvent::ToolUseStart {
            session_id,
            turn_id,
            tool_name,
            ..
        } => {
            let message = format!("Executing tool: {}", tool_name);
            UiMessage::Display {
                message,
                session_id,
                turn_id,
            }
        }
        ControllerEvent::ToolUse {
            session_id,
            turn_id,
            tool,
            display_name,
            display_title,
        } => {
            let tool_use_id = tool.id.clone();
            let display_name = display_name.unwrap_or_else(|| tool.name.clone());
            let display_title = display_title.unwrap_or_default();
            UiMessage::ToolExecuting {
                session_id,
                turn_id,
                tool_use_id,
                display_name,
                display_title,
            }
        }
        ControllerEvent::Complete {
            session_id,
            turn_id,
            stop_reason,
        } => UiMessage::Complete {
            stop_reason,
            session_id,
            turn_id,
            output_tokens: 0,
            input_tokens: 0,
        },
        ControllerEvent::Error {
            session_id,
            turn_id,
            error,
        } => UiMessage::Error {
            error,
            session_id,
            turn_id,
        },
        ControllerEvent::TokenUpdate {
            session_id,
            input_tokens,
            output_tokens,
            context_limit,
        } => UiMessage::TokenUpdate {
            turn_id: None,
            session_id,
            context_limit,
            input_tokens,
            output_tokens,
        },
        ControllerEvent::ToolResult {
            session_id,
            turn_id,
            tool_use_id,
            status,
            error,
            ..
        } => UiMessage::ToolCompleted {
            tool_use_id,
            status,
            error,
            session_id,
            turn_id,
        },
        ControllerEvent::CommandComplete {
            session_id,
            command,
            success,
            message,
        } => UiMessage::CommandComplete {
            command,
            success,
            message,
            session_id,
        },
        ControllerEvent::UserInteractionRequired {
            session_id,
            turn_id,
            tool_use_id,
            request,
        } => UiMessage::UserInteractionRequired {
            tool_use_id,
            request,
            session_id,
            turn_id,
        },
        ControllerEvent::PermissionRequired {
            session_id,
            turn_id,
            tool_use_id,
            request,
        } => UiMessage::PermissionRequired {
            tool_use_id,
            request,
            session_id,
            turn_id,
        },
        ControllerEvent::BatchPermissionRequired {
            session_id,
            turn_id,
            batch,
        } => UiMessage::BatchPermissionRequired {
            batch,
            session_id,
            turn_id,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    /// Replacement skills section shared by the `replace_skills_section` tests.
    const NEW_SKILLS_XML: &str = "<available_skills>\n <skill>new</skill>\n</available_skills>";

    #[test]
    fn test_convert_text_chunk_event() {
        let converted = convert_controller_event_to_ui_message(ControllerEvent::TextChunk {
            session_id: 1,
            text: "Hello".to_string(),
            turn_id: Some(TurnId::new_user_turn(1)),
        });
        let UiMessage::TextChunk {
            session_id, text, ..
        } = converted
        else {
            panic!("Expected TextChunk message");
        };
        assert_eq!(session_id, 1);
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_convert_error_event() {
        let converted = convert_controller_event_to_ui_message(ControllerEvent::Error {
            session_id: 1,
            error: "Test error".to_string(),
            turn_id: None,
        });
        let UiMessage::Error {
            session_id, error, ..
        } = converted
        else {
            panic!("Expected Error message");
        };
        assert_eq!(session_id, 1);
        assert_eq!(error, "Test error");
    }

    #[test]
    fn test_replace_skills_section_replaces_existing() {
        let replaced = replace_skills_section(
            "System prompt.\n\n<available_skills>\n <skill>old</skill>\n</available_skills>\n\nMore text.",
            NEW_SKILLS_XML,
        );
        for expected in ["<skill>new</skill>", "System prompt.", "More text."] {
            assert!(replaced.contains(expected));
        }
        assert!(!replaced.contains("<skill>old</skill>"));
    }

    #[test]
    fn test_replace_skills_section_no_existing() {
        let replaced = replace_skills_section("System prompt without skills.", NEW_SKILLS_XML);
        for expected in ["System prompt without skills.", "<skill>new</skill>"] {
            assert!(replaced.contains(expected));
        }
    }

    #[test]
    fn test_replace_skills_section_malformed_no_closing_tag() {
        let replaced = replace_skills_section(
            "System prompt.\n\n<available_skills>\n <skill>old</skill>\n\nNo closing tag.",
            NEW_SKILLS_XML,
        );
        for expected in ["<skill>old</skill>", "<skill>new</skill>"] {
            assert!(replaced.contains(expected));
        }
    }

    #[test]
    fn test_replace_skills_section_at_end() {
        let replaced = replace_skills_section(
            "System prompt.\n\n<available_skills>\n <skill>old</skill>\n</available_skills>",
            NEW_SKILLS_XML,
        );
        assert!(replaced.contains("<skill>new</skill>"));
        assert!(!replaced.contains("<skill>old</skill>"));
        assert!(replaced.starts_with("System prompt."));
    }
}