bevy_rig 0.1.0

Bevy ECS primitives and systems for modeling providers, agents, tools, sessions, runs, and workflows on top of Rig.
Documentation
use std::{
    any::Any,
    panic::{AssertUnwindSafe, catch_unwind},
    sync::Arc,
};

use bevy_ecs::prelude::*;
#[allow(deprecated)]
use rig::client::builder::AnyClient;
use rig::{
    client::ProviderClient,
    completion::{Chat, Message as RigMessage, Prompt, PromptError},
    providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        llamafile, mira, mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    },
};
use thiserror::Error;
use tokio::runtime::Runtime;

use crate::{
    agent::{AgentSpec, AgentToolRefs},
    provider::{ProviderKind, ProviderSpec},
    run::{
        Run, RunFailure, RunFinalized, RunOwner, RunPrompt, RunRequest, RunResultText, RunSession,
        RunStatus, RunStreamBuffer,
    },
    session::{self, ChatMessageRole},
};

/// Bevy resource holding the Tokio runtime on which Rig's async calls are driven.
///
/// The runtime sits behind an [`Arc`] so the resource is cheap to clone, letting callers
/// grab a handle to it and then block on it without keeping the resource borrowed.
#[derive(Resource, Clone)]
pub struct RigRuntime(pub Arc<Runtime>);

impl Default for RigRuntime {
    /// Creates a fresh Tokio [`Runtime`] via `Runtime::new`.
    ///
    /// # Panics
    /// Panics if Tokio cannot create the runtime.
    fn default() -> Self {
        let runtime = Runtime::new()
            .expect("bevy_rig could not create a Tokio runtime for Rig execution");
        Self(Arc::new(runtime))
    }
}

#[derive(Debug, Error)]
pub enum RigExecutionError {
    #[error("agent entity {0:?} is missing AgentSpec")]
    MissingAgentSpec(Entity),
    #[error("agent entity {0:?} is not configured with a provider")]
    MissingProvider(Entity),
    #[error("provider entity {0:?} is missing ProviderSpec")]
    MissingProviderSpec(Entity),
    #[error(
        "agent entity {agent:?} has {tool_count} attached Bevy tool(s), but the Rig tool bridge is not implemented yet"
    )]
    UnsupportedTools { agent: Entity, tool_count: usize },
    #[error("provider {provider:?} could not be initialized from the environment: {error}")]
    ProviderUnavailable {
        provider: ProviderKind,
        error: String,
    },
    #[error("provider {provider:?} does not expose Rig completion capabilities")]
    CompletionUnavailable { provider: ProviderKind },
    #[error("session entity {0:?} does not exist")]
    MissingSession(Entity),
    #[error("{0}")]
    PromptFailure(#[from] PromptError),
}

pub(crate) fn execute_agent_prompt(
    world: &mut World,
    agent: Entity,
    prompt: &str,
    session: Option<Entity>,
    current_user_message: Option<&str>,
) -> Result<String, RigExecutionError> {
    let runtime = world.resource::<RigRuntime>().0.clone();
    let agent_spec = world
        .get::<AgentSpec>(agent)
        .cloned()
        .ok_or(RigExecutionError::MissingAgentSpec(agent))?;
    let provider = agent_spec
        .provider
        .ok_or(RigExecutionError::MissingProvider(agent))?;
    let provider_spec = world
        .get::<ProviderSpec>(provider)
        .cloned()
        .ok_or(RigExecutionError::MissingProviderSpec(provider))?;
    let tool_count = world
        .get::<AgentToolRefs>(agent)
        .map(|refs| refs.0.len())
        .unwrap_or_default();

    if tool_count > 0 {
        return Err(RigExecutionError::UnsupportedTools { agent, tool_count });
    }

    let history = match session {
        Some(session) => collect_rig_history(world, session, current_user_message)?,
        None => Vec::new(),
    };

    execute_prompt_with_rig(runtime, provider_spec.kind, &agent_spec, prompt, history)
}

pub fn execute_rig_runs(world: &mut World) {
    let pending_runs = {
        let mut query = world.query_filtered::<(
            Entity,
            &RunOwner,
            &RunSession,
            &RunRequest,
            Option<&RunPrompt>,
            &RunStatus,
        ), (With<Run>, Without<RunFinalized>)>();

        query
            .iter(world)
            .filter(|(_, _, _, _, _, status)| **status == RunStatus::Queued)
            .map(|(run, owner, session, request, prompt, _)| PendingRigRun {
                run,
                owner: owner.0,
                session: session.0,
                request_prompt: request.prompt.clone(),
                execution_prompt: prompt
                    .map(|prompt| prompt.0.clone())
                    .unwrap_or_else(|| request.prompt.clone()),
            })
            .collect::<Vec<_>>()
    };

    for pending in pending_runs {
        let should_execute = match world.get::<AgentSpec>(pending.owner) {
            Some(spec) => spec.provider.is_some(),
            None => {
                insert_run_failure(
                    world,
                    pending.run,
                    RigExecutionError::MissingAgentSpec(pending.owner).to_string(),
                );
                continue;
            }
        };

        if !should_execute {
            continue;
        }

        if let Ok(mut entity) = world.get_entity_mut(pending.run) {
            entity.insert(RunStatus::Running);
        }

        match execute_agent_prompt(
            world,
            pending.owner,
            &pending.execution_prompt,
            Some(pending.session),
            Some(&pending.request_prompt),
        ) {
            Ok(text) => insert_run_success(world, pending.run, text),
            Err(error) => insert_run_failure(world, pending.run, error.to_string()),
        }
    }
}

/// Owned snapshot of a queued run, copied out of the world so the world can be mutated
/// while the run is being processed.
#[derive(Clone)]
struct PendingRigRun {
    /// Agent entity that owns the run (from [`RunOwner`]).
    owner: Entity,
    /// Session entity whose transcript supplies chat history (from [`RunSession`]).
    session: Entity,
    /// The run entity whose status and result components are updated.
    run: Entity,
    /// Prompt actually sent to the model: the [`RunPrompt`] text if present, otherwise the
    /// request prompt.
    execution_prompt: String,
    /// Prompt text from [`RunRequest`]; used to recognise the current user turn in history.
    request_prompt: String,
}

/// Records a successful completion on `run`.
///
/// Sets [`RunStatus::Completed`], stores `text` as [`RunResultText`], and removes any stale
/// [`RunFailure`] and [`RunStreamBuffer`]. Does nothing if `run` no longer exists.
fn insert_run_success(world: &mut World, run: Entity, text: String) {
    if let Ok(mut run_entity) = world.get_entity_mut(run) {
        run_entity
            .insert((RunStatus::Completed, RunResultText(text)))
            .remove::<RunFailure>()
            .remove::<RunStreamBuffer>();
    }
}

fn insert_run_failure(world: &mut World, run: Entity, error: String) {
    let Ok(mut entity) = world.get_entity_mut(run) else {
        return;
    };

    entity.insert((RunStatus::Failed, RunFailure(error)));
    entity.remove::<RunResultText>();
    entity.remove::<RunStreamBuffer>();
}

/// Converts `session`'s transcript into Rig chat messages.
///
/// If the transcript ends with a user message whose text equals `current_user_message`, that
/// message is left out. The caller sends its prompt separately, so keeping that message
/// would repeat the current turn in the history.
///
/// # Errors
/// Returns [`RigExecutionError::MissingSession`] when `session` is not an existing entity.
fn collect_rig_history(
    world: &World,
    session: Entity,
    current_user_message: Option<&str>,
) -> Result<Vec<RigMessage>, RigExecutionError> {
    if world.get_entity(session).is_err() {
        return Err(RigExecutionError::MissingSession(session));
    }

    let mut transcript = session::collect_transcript(world, session);

    let ends_with_current_turn = current_user_message.is_some_and(|current| {
        matches!(
            transcript.last(),
            Some((ChatMessageRole::User, text)) if text == current
        )
    });
    if ends_with_current_turn {
        transcript.pop();
    }

    let history: Vec<RigMessage> = transcript
        .into_iter()
        .map(|(role, text)| match role {
            ChatMessageRole::User => RigMessage::user(text),
            ChatMessageRole::Assistant => RigMessage::assistant(text),
            ChatMessageRole::System => RigMessage::system(text),
        })
        .collect();
    Ok(history)
}

/// Builds a Rig agent for `agent_spec.model` on `provider` and runs `prompt`, blocking the
/// calling thread on `runtime` until the reply arrives.
///
/// With an empty `history`, the prompt is sent through Rig's `prompt` call. Otherwise it is
/// sent through `chat` together with `history`. `agent_spec.max_turns`, when set, becomes
/// the agent's default max turns.
///
/// # Errors
/// Propagates errors from [`load_provider_client`]. Returns
/// [`RigExecutionError::CompletionUnavailable`] if the client has no completion interface,
/// and wraps any Rig [`PromptError`].
// `AnyClient` is deprecated in Rig (its import carries the same allow).
#[allow(deprecated)]
fn execute_prompt_with_rig(
    runtime: Arc<Runtime>,
    provider: ProviderKind,
    agent_spec: &AgentSpec,
    prompt: &str,
    history: Vec<RigMessage>,
) -> Result<String, RigExecutionError> {
    let client = load_provider_client(provider)?;
    let Some(completion) = client.as_completion() else {
        return Err(RigExecutionError::CompletionUnavailable { provider });
    };

    let builder = completion.agent(&agent_spec.model);
    let builder = match agent_spec.max_turns {
        Some(turns) => builder.default_max_turns(turns),
        None => builder,
    };
    let rig_agent = builder.build();

    let owned_prompt = prompt.to_owned();
    let reply = if history.is_empty() {
        runtime.block_on(async { rig_agent.prompt(owned_prompt).await })?
    } else {
        runtime.block_on(async { rig_agent.chat(owned_prompt, history).await })?
    };
    Ok(reply)
}

#[allow(deprecated)]
fn load_provider_client(provider: ProviderKind) -> Result<AnyClient, RigExecutionError> {
    macro_rules! from_env_client {
        ($client:path) => {
            catch_unwind(AssertUnwindSafe(|| AnyClient::new(<$client>::from_env())))
        };
    }

    let result = match provider {
        ProviderKind::Anthropic => from_env_client!(anthropic::Client),
        ProviderKind::Azure => from_env_client!(azure::Client),
        ProviderKind::Cohere => from_env_client!(cohere::Client),
        ProviderKind::DeepSeek => from_env_client!(deepseek::Client),
        ProviderKind::Galadriel => from_env_client!(galadriel::Client),
        ProviderKind::Gemini => from_env_client!(gemini::Client),
        ProviderKind::Groq => from_env_client!(groq::Client),
        ProviderKind::HuggingFace => from_env_client!(huggingface::Client),
        ProviderKind::Hyperbolic => from_env_client!(hyperbolic::Client),
        ProviderKind::Llamafile => from_env_client!(llamafile::Client),
        ProviderKind::Mira => from_env_client!(mira::Client),
        ProviderKind::Mistral => from_env_client!(mistral::Client),
        ProviderKind::Moonshot => from_env_client!(moonshot::Client),
        ProviderKind::Ollama => from_env_client!(ollama::Client),
        ProviderKind::OpenAi => from_env_client!(openai::Client),
        ProviderKind::OpenRouter => from_env_client!(openrouter::Client),
        ProviderKind::Perplexity => from_env_client!(perplexity::Client),
        ProviderKind::Together => from_env_client!(together::Client),
        ProviderKind::XAi => from_env_client!(xai::Client),
    };

    result.map_err(|payload| RigExecutionError::ProviderUnavailable {
        provider,
        error: panic_payload_to_string(payload),
    })
}

/// Extracts a human-readable message from a panic payload.
///
/// Panic payloads are usually a `String` (formatted message) or a `&'static str` (literal
/// message). Any other payload type yields a generic placeholder message.
fn panic_payload_to_string(payload: Box<dyn Any + Send>) -> String {
    payload
        .downcast::<String>()
        .map(|owned| *owned)
        .or_else(|rest| rest.downcast::<&'static str>().map(|literal| (*literal).to_owned()))
        .unwrap_or_else(|_| String::from("unknown provider initialization panic"))
}