//! Example ported from rig-core 0.35.0 — "An opinionated library for
//! building LLM powered applications" (see the rig-core documentation).
//!
//! Preserves the live request-hook example as provider-local regression coverage.

use anyhow::{Result, anyhow};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use rig::agent::{HookAction, PromptHook};
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{CompletionModel, CompletionResponse, Message, Prompt};
use rig::message::UserContent;
use rig::providers::openai;

use crate::support::assert_nonempty_response;

/// Prompt hook that tags each captured prompt with a session id and counts
/// how many times each hook callback fired, so the test can assert on both
/// after the live request completes.
#[derive(Clone)]
struct SessionIdHook<'a> {
    /// Identifier prefixed onto the captured prompt text as `"<id>:<prompt>"`.
    session_id: &'a str,
    /// Number of times `on_completion_call` has run.
    prompt_calls: Arc<AtomicUsize>,
    /// Number of times `on_completion_response` has run.
    response_calls: Arc<AtomicUsize>,
    /// Last prompt text seen, stored as `"<session_id>:<prompt text>"`.
    seen_prompt: Arc<Mutex<Option<String>>>,
    /// `Debug` rendering of the last response choice seen.
    seen_response: Arc<Mutex<Option<String>>>,
}

impl<'a, M> PromptHook<M> for SessionIdHook<'a>
where
    M: CompletionModel,
{
    /// Captures the outgoing user prompt, prefixed with the session id,
    /// and bumps the call counter. Terminates the run on a non-user
    /// message or a poisoned lock.
    async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) -> HookAction {
        let content = match prompt {
            Message::User { content } => content,
            _ => return HookAction::terminate("expected a user message"),
        };

        // Keep only the text parts of the user content, one per line.
        let mut text_parts: Vec<String> = Vec::new();
        for part in content.iter() {
            if let UserContent::Text(text) = part {
                text_parts.push(text.text.clone());
            }
        }
        let prompt_text = text_parts.join("\n");

        self.prompt_calls.fetch_add(1, Ordering::SeqCst);
        if let Ok(mut slot) = self.seen_prompt.lock() {
            *slot = Some(format!("{}:{prompt_text}", self.session_id));
            HookAction::cont()
        } else {
            // Lock poisoned by a panicking holder — abort the run.
            HookAction::terminate("prompt hook state unavailable")
        }
    }

    /// Captures a `Debug` rendering of the model's chosen response and
    /// bumps the response counter. Terminates the run on a poisoned lock.
    async fn on_completion_response(
        &self,
        _prompt: &Message,
        response: &CompletionResponse<M::Response>,
    ) -> HookAction {
        self.response_calls.fetch_add(1, Ordering::SeqCst);
        if let Ok(mut slot) = self.seen_response.lock() {
            *slot = Some(format!("{:?}", response.choice));
            HookAction::cont()
        } else {
            // Lock poisoned by a panicking holder — abort the run.
            HookAction::terminate("response hook state unavailable")
        }
    }
}

/// Live regression test: a hook attached via `with_hook` must fire exactly
/// once for the request and once for the response, and must capture both the
/// outgoing prompt (tagged with the session id) and a non-empty rendering of
/// the model's reply.
///
/// Fix over the previous version: the assertion now also verifies the
/// `"abc123:"` session-id prefix that `on_completion_call` writes — the
/// tagging behavior this hook exists to exercise was previously unchecked.
#[tokio::test]
#[ignore = "requires OPENAI_API_KEY"]
async fn request_hook_records_prompt_and_response() -> Result<()> {
    let agent = openai::Client::from_env()
        .agent(openai::GPT_4O)
        .preamble("You are a comedian here to entertain the user using humour and jokes.")
        .build();

    let hook = SessionIdHook {
        session_id: "abc123",
        prompt_calls: Arc::new(AtomicUsize::new(0)),
        response_calls: Arc::new(AtomicUsize::new(0)),
        seen_prompt: Arc::new(Mutex::new(None)),
        seen_response: Arc::new(Mutex::new(None)),
    };

    let response = agent
        .prompt("Entertain me!")
        .with_hook(hook.clone())
        .await?;

    assert_nonempty_response(&response);
    // A single round trip, so each hook callback fires exactly once.
    assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1);
    assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1);

    let seen_prompt = hook
        .seen_prompt
        .lock()
        .map_err(|_| anyhow!("prompt hook state unavailable"))?
        .clone();
    let seen_response = hook
        .seen_response
        .lock()
        .map_err(|_| anyhow!("response hook state unavailable"))?
        .clone();

    // The hook stores "<session_id>:<prompt>", so check both the tag and
    // the original prompt text survived the round trip.
    assert!(
        seen_prompt
            .as_deref()
            .is_some_and(|prompt| prompt.starts_with("abc123:")
                && prompt.contains("Entertain me!"))
    );
    assert!(
        seen_response
            .as_deref()
            .is_some_and(|captured| !captured.is_empty())
    );

    Ok(())
}