bevy_rig 0.1.0

Bevy ECS primitives and systems for modeling providers, agents, tools, sessions, runs, and workflows on top of Rig.
Documentation
use bevy_ecs::{hierarchy::ChildOf, prelude::*};
use thiserror::Error;

use crate::{model::ModelSpec, session};

/// Marker component identifying an entity as an agent.
///
/// Carries no data; it is included in every [`AgentBundle`].
#[derive(Component, Clone, Debug, PartialEq, Eq)]
pub struct Agent;

/// Configuration component describing an agent.
#[derive(Component, Clone, Debug, PartialEq, Eq)]
pub struct AgentSpec {
    /// Agent name; `spawn_agent` also uses it to build the `"<name> session"` label.
    pub name: String,
    /// Model name. `spawn_agent_from_model` and `bind_model` copy it from the
    /// bound model entity's `ModelSpec::name`.
    pub model: String,
    /// Optional turn limit; `None` by default. Nothing in this module reads it —
    /// NOTE(review): confirm where the limit is enforced.
    pub max_turns: Option<usize>,
    /// Optional provider entity. When the agent is bound to a model entity this
    /// is set to the model's `ChildOf` parent, if it has one.
    pub provider: Option<Entity>,
}

impl AgentSpec {
    /// Builds a spec from an agent name and a model name.
    ///
    /// `max_turns` and `provider` both start out as `None`; use the `with_*`
    /// methods to fill them in.
    pub fn new(name: impl Into<String>, model: impl Into<String>) -> Self {
        let name = name.into();
        let model = model.into();

        Self {
            name,
            model,
            max_turns: None,
            provider: None,
        }
    }

    /// Returns the spec with `provider` set to the given entity.
    pub fn with_provider(self, provider: Entity) -> Self {
        Self {
            provider: Some(provider),
            ..self
        }
    }

    /// Returns the spec with `max_turns` set to the given value.
    pub fn with_max_turns(self, max_turns: usize) -> Self {
        Self {
            max_turns: Some(max_turns),
            ..self
        }
    }
}

/// Tool entities linked to an agent.
///
/// Defaults to an empty list; `attach_tool` appends to it, skipping entities
/// that are already present.
#[derive(Component, Clone, Debug, Default, PartialEq, Eq)]
pub struct AgentToolRefs(pub Vec<Entity>);

/// Context entities linked to an agent.
///
/// Defaults to an empty list; `attach_context` appends to it, skipping entities
/// that are already present.
#[derive(Component, Clone, Debug, Default, PartialEq, Eq)]
pub struct AgentContextRefs(pub Vec<Entity>);

/// Session entity associated with an agent.
///
/// `spawn_agent` stores here the entity returned by `session::spawn_session`.
#[derive(Component, Clone, Copy, Debug, PartialEq, Eq)]
pub struct PrimarySession(pub Entity);

/// Model entity an agent is bound to.
///
/// Inserted on the agent by `spawn_agent_from_model` and `bind_model`; it is
/// not part of [`AgentBundle`], so agents created with plain `spawn_agent` do
/// not have it.
#[derive(Component, Clone, Copy, Debug, PartialEq, Eq)]
pub struct AgentModelRef(pub Entity);

/// Bundle of the components every agent entity is spawned with.
#[derive(Bundle)]
pub struct AgentBundle {
    /// Marker identifying the entity as an agent.
    pub agent: Agent,
    /// Name, model and optional settings of the agent.
    pub spec: AgentSpec,
    /// Linked tool entities.
    pub tools: AgentToolRefs,
    /// Linked context entities.
    pub contexts: AgentContextRefs,
    /// Session entity associated with the agent.
    pub primary_session: PrimarySession,
}

impl AgentBundle {
    /// Assembles the bundle for a new agent.
    ///
    /// The tool and context reference lists start empty, and
    /// `primary_session` is wrapped in [`PrimarySession`].
    pub fn new(spec: AgentSpec, primary_session: Entity) -> Self {
        let primary_session = PrimarySession(primary_session);
        let tools = AgentToolRefs::default();
        let contexts = AgentContextRefs::default();

        Self {
            agent: Agent,
            spec,
            tools,
            contexts,
            primary_session,
        }
    }
}

/// Entity ids returned when an agent is spawned.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AgentHandles {
    /// The spawned agent entity.
    pub agent: Entity,
    /// The session entity spawned together with the agent.
    pub session: Entity,
}

/// Errors returned when binding an agent to a model entity.
#[derive(Debug, Error)]
pub enum AgentModelError {
    /// The agent entity passed to `bind_model` is not present in the world.
    #[error("agent entity {0:?} does not exist")]
    UnknownAgent(Entity),
    /// The model entity is not present in the world.
    #[error("model entity {0:?} does not exist")]
    UnknownModel(Entity),
    /// The model entity exists but carries no `ModelSpec` component.
    #[error("model entity {0:?} is missing ModelSpec")]
    MissingModelSpec(Entity),
}

/// Errors returned when linking a tool or context entity to an agent.
#[derive(Debug, Error)]
pub enum AgentLinkError {
    /// The agent entity is not present in the world.
    #[error("agent entity {0:?} does not exist")]
    UnknownAgent(Entity),
    /// The entity to be linked (tool or context) is not present in the world.
    #[error("linked entity {0:?} does not exist")]
    UnknownLinkedEntity(Entity),
}

/// Spawns an agent entity together with its session.
///
/// The session is created first through `session::spawn_session`, labelled
/// `"<spec.name> session"`; the agent is then spawned with an
/// [`AgentBundle`] that points at that session.
///
/// Returns the ids of both entities.
pub fn spawn_agent(world: &mut World, spec: AgentSpec) -> AgentHandles {
    let session_label = format!("{} session", spec.name);
    let session = session::spawn_session(world, session_label);
    let bundle = AgentBundle::new(spec, session);

    AgentHandles {
        agent: world.spawn(bundle).id(),
        session,
    }
}

/// Spawns an agent whose model settings are taken from a model entity.
///
/// The agent's model name is the model entity's `ModelSpec::name`; if the
/// model entity has a `ChildOf` parent, that parent becomes the agent's
/// provider. The new agent also receives an [`AgentModelRef`] pointing at
/// `model`.
///
/// # Errors
///
/// * [`AgentModelError::UnknownModel`] if `model` is not in the world.
/// * [`AgentModelError::MissingModelSpec`] if `model` exists but has no
///   `ModelSpec` component.
///
/// Nothing is spawned when an error is returned.
pub fn spawn_agent_from_model(
    world: &mut World,
    name: impl Into<String>,
    model: Entity,
) -> Result<AgentHandles, AgentModelError> {
    let Some(model_spec) = world.get::<ModelSpec>(model).cloned() else {
        // Distinguish "no such entity" from "entity without a ModelSpec".
        let missing_entity = world.get_entity(model).is_err();
        return Err(if missing_entity {
            AgentModelError::UnknownModel(model)
        } else {
            AgentModelError::MissingModelSpec(model)
        });
    };
    let parent = world.get::<ChildOf>(model).map(|link| link.parent());

    let base = AgentSpec::new(name, model_spec.name);
    let spec = match parent {
        Some(provider) => base.with_provider(provider),
        None => base,
    };

    let handles = spawn_agent(world, spec);
    world.entity_mut(handles.agent).insert(AgentModelRef(model));
    Ok(handles)
}

/// Rebinds an existing agent to a model entity.
///
/// Copies the model's `ModelSpec::name` into the agent's [`AgentSpec::model`],
/// sets [`AgentSpec::provider`] to the model's `ChildOf` parent (or `None` if
/// it has no parent), and inserts an [`AgentModelRef`] pointing at `model`.
///
/// # Errors
///
/// * [`AgentModelError::UnknownAgent`] if `agent` is not in the world, or if
///   it exists but carries no [`AgentSpec`] (i.e. it is not an agent).
/// * [`AgentModelError::UnknownModel`] if `model` is not in the world.
/// * [`AgentModelError::MissingModelSpec`] if `model` exists but has no
///   `ModelSpec` component.
///
/// The agent is left unmodified whenever an error is returned.
pub fn bind_model(world: &mut World, agent: Entity, model: Entity) -> Result<(), AgentModelError> {
    if world.get_entity(agent).is_err() {
        return Err(AgentModelError::UnknownAgent(agent));
    }

    // Only the name is needed, so avoid cloning the whole `ModelSpec`.
    let model_name = world
        .get::<ModelSpec>(model)
        .map(|spec| spec.name.clone())
        .ok_or_else(|| {
            if world.get_entity(model).is_err() {
                AgentModelError::UnknownModel(model)
            } else {
                AgentModelError::MissingModelSpec(model)
            }
        })?;
    let provider = world.get::<ChildOf>(model).map(ChildOf::parent);

    let mut entity = world.entity_mut(agent);
    {
        // An arbitrary entity may be passed as `agent`; report it as an error
        // instead of panicking when it has no `AgentSpec`.
        let Some(mut spec) = entity.get_mut::<AgentSpec>() else {
            return Err(AgentModelError::UnknownAgent(agent));
        };
        spec.model = model_name;
        spec.provider = provider;
    }
    entity.insert(AgentModelRef(model));

    Ok(())
}

pub fn attach_tool(world: &mut World, agent: Entity, tool: Entity) -> Result<(), AgentLinkError> {
    ensure_entities_exist(world, agent, tool)?;

    let mut agent_entity = world.entity_mut(agent);
    let mut refs = agent_entity
        .get_mut::<AgentToolRefs>()
        .expect("agent bundles always include tool references");
    if !refs.0.contains(&tool) {
        refs.0.push(tool);
    }

    Ok(())
}

pub fn attach_context(
    world: &mut World,
    agent: Entity,
    context: Entity,
) -> Result<(), AgentLinkError> {
    ensure_entities_exist(world, agent, context)?;

    let mut agent_entity = world.entity_mut(agent);
    let mut refs = agent_entity
        .get_mut::<AgentContextRefs>()
        .expect("agent bundles always include context references");
    if !refs.0.contains(&context) {
        refs.0.push(context);
    }

    Ok(())
}

/// Checks that both the agent and the entity to be linked are in the world.
///
/// The agent is checked first, so when both are missing the result is
/// [`AgentLinkError::UnknownAgent`]; a missing `entity` yields
/// [`AgentLinkError::UnknownLinkedEntity`].
fn ensure_entities_exist(
    world: &mut World,
    agent: Entity,
    entity: Entity,
) -> Result<(), AgentLinkError> {
    world
        .get_entity(agent)
        .map_err(|_| AgentLinkError::UnknownAgent(agent))?;

    world
        .get_entity(entity)
        .map(|_| ())
        .map_err(|_| AgentLinkError::UnknownLinkedEntity(entity))
}