use std::{collections::HashMap, sync::Arc, time::Duration};
use anyhow::Result;
use tokio::{
sync::{mpsc, RwLock},
task::{self},
};
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use uuid::Uuid;
use crate::{
agent::{self, RunningAgent},
frontend::App,
git, indexing,
repository::Repository,
util::accept_non_zero_exit,
};
use super::{
command::{Command, CommandEvent},
responder::{CommandResponse, Responder},
};
/// Shared map of the agents started so far, keyed by the UUID of the command event that
/// started them.
type AgentRegistry = Arc<RwLock<HashMap<Uuid, RunningAgent>>>;

/// Backend dispatcher that receives [`CommandEvent`]s over an unbounded channel and
/// executes them against a [`Repository`].
///
/// The handler owns both ends of the channel: the sender is handed to the frontend via
/// `register_ui`, and the receiver is taken out when the command loop is started.
pub struct CommandHandler {
    /// Receiving end of the command channel; taken (left as `None`) by `start`.
    rx: Option<mpsc::UnboundedReceiver<CommandEvent>>,
    /// Sending end of the command channel; clones are given to the UI.
    tx: mpsc::UnboundedSender<CommandEvent>,
    /// Repository every command operates on.
    repository: Arc<Repository>,
    /// Agents started by chat commands, shared between concurrently running commands.
    agents: AgentRegistry,
}
impl CommandHandler {
pub fn from_repository(repository: impl Into<Repository>) -> Self {
let (tx, rx) = mpsc::unbounded_channel();
CommandHandler {
rx: Some(rx),
tx,
repository: Arc::new(repository.into()),
agents: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn register_ui(&mut self, app: &mut App) {
app.command_tx = Some(self.tx.clone());
}
pub fn start(mut self) -> AbortOnDropHandle<()> {
let repository = Arc::clone(&self.repository);
let mut rx = self.rx.take().expect("Expected a receiver");
let this_handler = Arc::new(self);
AbortOnDropHandle::new(task::spawn(async move {
let mut joinset = tokio::task::JoinSet::new();
while let Some(event) = rx.recv().await {
if event.command().is_quit() {
tracing::warn!("Backend received quit command, shutting down");
joinset.shutdown().await;
tracing::warn!("Backend shutdown complete");
break;
}
let repository = Arc::clone(&repository);
let this_handler = Arc::clone(&this_handler);
joinset.spawn(async move {
let result = this_handler.handle_command_event(&repository, &event, &event.command()).await;
event.responder().send(CommandResponse::Completed(event.uuid()));
if let Err(error) = result {
tracing::error!(?error, cmd = %event.command(), "Failed to handle command {cmd} with error {error:#}", cmd= event.command());
event.responder().system_message(&format!(
"Failed to handle command: {error:#}"
));
};
});
}
tracing::warn!("CommandHandler shutting down");
}))
}
#[tracing::instrument(skip_all, fields(otel.name = %cmd.to_string(), uuid = %event.uuid()), err)]
async fn handle_command_event(
&self,
repository: &Repository,
event: &CommandEvent,
cmd: &Command,
) -> Result<()> {
let now = std::time::Instant::now();
tracing::warn!("Handling command {cmd}");
#[allow(clippy::match_wildcard_for_single_variants)]
match cmd {
Command::StopAgent => {
self.stop_agent(event.uuid(), event.clone_responder())
.await?;
}
Command::IndexRepository { .. } => {
indexing::index_repository(repository, Some(event.clone_responder())).await?;
}
Command::ShowConfig => event
.responder()
.system_message(&toml::to_string_pretty(repository.config())?),
Command::Chat { ref message } => {
let message = message.clone();
let agent = self
.find_or_start_agent_by_uuid(event.uuid(), &message, event.clone_responder())
.await?;
let token = agent.cancel_token.clone();
tokio::select! {
() = token.cancelled() => Ok(()),
result = agent.query(&message) => result,
}?;
}
Command::Diff => {
let Some(agent) = self.find_agent_by_uuid(event.uuid()).await else {
event
.responder()
.system_message("No agent found (yet), is it starting up?");
return Ok(());
};
let base_sha = &agent.agent_environment.start_ref;
let diff = git::util::diff(agent.executor.as_ref(), &base_sha).await?;
event.responder().system_message(&diff);
}
Command::Exec { cmd } => {
let Some(agent) = self.find_agent_by_uuid(event.uuid()).await else {
event
.responder()
.system_message("No agent found (yet), is it starting up?");
return Ok(());
};
let output = accept_non_zero_exit(agent.executor.exec_cmd(cmd).await)?.output;
event.responder().system_message(&output);
}
Command::RetryChat => {
let Some(agent) = self.find_agent_by_uuid(event.uuid()).await else {
event
.responder()
.system_message("No agent found (yet), is it starting up?");
return Ok(());
};
let mut token = agent.cancel_token.clone();
if token.is_cancelled() {
if let Some(agent) = self.agents.write().await.get_mut(&event.uuid()) {
agent.cancel_token = CancellationToken::new();
token = agent.cancel_token.clone();
}
}
agent.agent_context.redrive().await;
tokio::select! {
() = token.cancelled() => Ok(()),
result = agent.run() => result,
}?;
}
Command::Quit { .. } => unreachable!("Quit should be handled earlier"),
}
tokio::time::sleep(Duration::from_millis(100)).await;
let mut elapsed = now.elapsed();
if cfg!(debug_assertions) {
elapsed = Duration::from_secs(0);
}
event.responder().system_message(&format!(
"Command {cmd} successful in {} seconds",
elapsed.as_secs_f64().round()
));
Ok(())
}
async fn find_or_start_agent_by_uuid(
&self,
uuid: Uuid,
query: &str,
responder: Arc<dyn Responder>,
) -> Result<RunningAgent> {
if let Some(agent) = self.find_agent_by_uuid(uuid).await {
if let Some(agent) = self.agents.write().await.get_mut(&uuid) {
agent.cancel_token = CancellationToken::new();
}
return Ok(agent);
}
let running_agent = agent::start_agent(uuid, &self.repository, query, responder).await?;
let cloned = running_agent.clone();
self.agents.write().await.insert(uuid, running_agent);
Ok(cloned)
}
async fn find_agent_by_uuid(&self, uuid: Uuid) -> Option<RunningAgent> {
if let Some(agent) = self.agents.read().await.get(&uuid) {
return Some(agent.clone());
}
None
}
async fn stop_agent(&self, uuid: Uuid, responder: Arc<dyn Responder>) -> Result<()> {
let agents = self.agents.read().await;
let Some(agent) = agents.get(&uuid) else {
responder.system_message("No agent found (yet), is it starting up?");
return Ok(());
};
if agent.cancel_token.is_cancelled() {
responder.system_message("Agent already stopped");
return Ok(());
}
agent.stop().await;
responder.system_message("Agent stopped");
Ok(())
}
}