ai 0.4.0

A simple-to-use LLM library for Rust with streaming, tool calling, OAuth helpers, and a lightweight agent loop.
Documentation
use crate::AssistantEventStream;
use crate::env_api_keys::get_env_api_key;
use crate::types::{
    AssistantMessage, AssistantMessageEvent, Context, Model, SimpleStreamOptions, StreamOptions,
};
use crate::{Error, Result};

/// Reports whether the caller supplied a usable API key.
///
/// A key counts as explicit only when it is present and still non-empty after
/// surrounding whitespace is trimmed; `None`, `""` and whitespace-only values
/// all yield `false`.
fn has_explicit_api_key(api_key: &Option<String>) -> bool {
    match api_key {
        Some(candidate) => !candidate.trim().is_empty(),
        None => false,
    }
}

/// Fills in `options.api_key` from the environment when the caller gave none.
///
/// An explicit (non-blank) key is always left untouched. Otherwise the key
/// returned by `get_env_api_key` for `model.provider` is used if there is one;
/// when the lookup yields nothing, `options` is returned unchanged.
fn with_env_api_key(model: &Model, mut options: StreamOptions) -> StreamOptions {
    // Caller-provided keys win; do not even consult the environment.
    if has_explicit_api_key(&options.api_key) {
        return options;
    }

    if let Some(env_key) = get_env_api_key(&model.provider) {
        options.api_key = Some(env_key);
    }
    options
}

fn with_env_api_key_simple(model: &Model, mut options: SimpleStreamOptions) -> SimpleStreamOptions {
    options.stream = with_env_api_key(model, options.stream);
    options
}

pub fn stream(
    model: Model,
    context: Context,
    options: Option<StreamOptions>,
) -> Result<AssistantEventStream> {
    let api = model
        .language_api()
        .ok_or_else(|| Error::unsupported_capability(model.provider.clone(), "language models"))?;
    let options = with_env_api_key(&model, options.unwrap_or_default());
    api.stream(model, context, options)
}

/// Runs [`stream`] and waits for the final assistant message.
///
/// # Errors
///
/// Propagates any error from [`stream`] or from
/// [`final_message_from_stream`].
pub async fn complete(
    model: Model,
    context: Context,
    options: Option<StreamOptions>,
) -> Result<AssistantMessage> {
    let events = stream(model, context, options)?;
    final_message_from_stream(events).await
}

pub fn stream_simple(
    model: Model,
    context: Context,
    options: Option<SimpleStreamOptions>,
) -> Result<AssistantEventStream> {
    let api = model
        .language_api()
        .ok_or_else(|| Error::unsupported_capability(model.provider.clone(), "language models"))?;
    let options = with_env_api_key_simple(&model, options.unwrap_or_default());
    api.stream_simple(model, context, options)
}

/// Runs [`stream_simple`] and waits for the final assistant message.
///
/// # Errors
///
/// Propagates any error from [`stream_simple`] or from
/// [`final_message_from_stream`].
pub async fn complete_simple(
    model: Model,
    context: Context,
    options: Option<SimpleStreamOptions>,
) -> Result<AssistantMessage> {
    let events = stream_simple(model, context, options)?;
    final_message_from_stream(events).await
}

pub async fn final_message_from_stream(
    mut stream: AssistantEventStream,
) -> Result<AssistantMessage> {
    while let Some(event) = futures::StreamExt::next(&mut stream).await {
        match event? {
            AssistantMessageEvent::Done { message, .. } => return Ok(message),
            AssistantMessageEvent::Error { error, .. } => return Ok(error),
            _ => {}
        }
    }
    Err(Error::StreamClosed)
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use crate::provider::LanguageModelApi;
    use crate::types::{
        AssistantContent, AssistantMessageEvent, ModelCost, ModelInput, StopReason, TextContent,
        Usage,
    };

    use super::*;

    /// Serializes the tests in this module that mutate process environment variables.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// RAII guard that remembers one environment variable and puts it back.
    ///
    /// The captured state is restored when the guard is dropped, so a test
    /// that panics part-way through (e.g. on a failed assertion) does not
    /// leave the variable overwritten for the rest of the process.
    struct SavedEnv {
        /// Name of the environment variable being protected.
        key: &'static str,
        /// Value at capture time; `None` if the variable was unset or could
        /// not be read as Unicode (`std::env::var(..).ok()`).
        value: Option<String>,
    }

    impl SavedEnv {
        /// Records the current value of `key` so it can be restored later.
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                value: std::env::var(key).ok(),
            }
        }

        /// Restores the captured state now by consuming the guard.
        ///
        /// Equivalent to dropping the guard; kept so call sites can make the
        /// restore point explicit.
        fn restore(self) {
            drop(self);
        }
    }

    impl Drop for SavedEnv {
        fn drop(&mut self) {
            // SAFETY: callers are expected to hold `ENV_LOCK` for the guard's
            // whole lifetime, so no other test in this module touches the
            // environment concurrently.
            unsafe {
                match self.value.take() {
                    Some(previous) => std::env::set_var(self.key, previous),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Test double for [`LanguageModelApi`] that records the API key it is
    /// handed by `stream_simple`.
    #[derive(Clone)]
    struct TestLanguageModelApi {
        /// Identifier reported through `id()`.
        api: &'static str,
        /// Shared slot that receives the `api_key` seen by `stream_simple`.
        observed_key: Arc<Mutex<Option<String>>>,
    }

    impl LanguageModelApi for TestLanguageModelApi {
        fn id(&self) -> &str {
            self.api
        }

        /// Not expected to be reached by the tests in this module; panics if called.
        fn stream(
            &self,
            _model: Model,
            _context: Context,
            _options: StreamOptions,
        ) -> Result<AssistantEventStream> {
            panic!("stream should not be called")
        }

        /// Stores the received API key, then answers with a single `Done` event.
        fn stream_simple(
            &self,
            model: Model,
            _context: Context,
            options: SimpleStreamOptions,
        ) -> Result<AssistantEventStream> {
            let received_key = options.stream.api_key.clone();
            {
                // Keep the guard scoped so the lock is released before building the stream.
                let mut slot = self
                    .observed_key
                    .lock()
                    .expect("observed key lock poisoned");
                *slot = received_key;
            }
            Ok(done_stream(&model))
        }
    }

    /// Builds a minimal text-only `Model` for provider `"openai"` with the
    /// given API identifier and optional language API implementation.
    fn test_model(api: &str, language_api: Option<Arc<dyn LanguageModelApi>>) -> Model {
        let mock_label = String::from("mock");
        Model {
            id: mock_label.clone(),
            name: mock_label,
            api: api.to_owned(),
            provider: String::from("openai"),
            base_url: String::from("https://example.invalid"),
            reasoning: false,
            input: vec![ModelInput::Text],
            cost: ModelCost::default(),
            context_window: 8192,
            max_tokens: 2048,
            language_api,
            ..Model::default()
        }
    }

    /// Creates an event stream that carries exactly one `Done` event.
    ///
    /// The event's message contains the text `"ok"`, copies `api`, `provider`
    /// and `id` from `model`, and has stop reason `StopReason::Stop`.
    fn done_stream(model: &Model) -> AssistantEventStream {
        let message = AssistantMessage {
            content: vec![AssistantContent::Text(TextContent {
                text: "ok".to_string(),
                text_signature: None,
            })],
            api: model.api.clone(),
            provider: model.provider.clone(),
            model: model.id.clone(),
            response_model: None,
            response_id: None,
            diagnostics: Vec::new(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: crate::utils::time::now_millis(),
        };
        let (mut sender, stream) = crate::create_assistant_message_event_stream();
        // `message` is not needed afterwards, so move it into the event
        // instead of cloning it.
        sender.push(AssistantMessageEvent::Done {
            reason: message.stop_reason,
            message,
        });
        stream
    }

    /// With no options supplied, the provider must receive the key taken from
    /// `OPENAI_API_KEY`.
    #[tokio::test(flavor = "current_thread")]
    async fn stream_simple_injects_env_api_key_before_provider_dispatch() {
        let _guard = ENV_LOCK.lock().await;
        let saved_openai_key = SavedEnv::capture("OPENAI_API_KEY");
        // SAFETY: `ENV_LOCK` is held, so no other test in this module touches
        // the environment while it is being modified.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "env-openai-key");
        }

        let observed_key = Arc::new(Mutex::new(None));
        let provider_api = Arc::new(TestLanguageModelApi {
            api: "stream-env-key-test",
            observed_key: Arc::clone(&observed_key),
        });
        let model = test_model("stream-env-key-test", Some(provider_api));

        let events =
            stream_simple(model, Context::default(), None).expect("stream_simple should dispatch");
        let _message = crate::stream::final_message_from_stream(events)
            .await
            .expect("stream result");

        let seen_by_provider = observed_key
            .lock()
            .expect("observed key lock poisoned")
            .clone();
        assert_eq!(seen_by_provider.as_deref(), Some("env-openai-key"));

        saved_openai_key.restore();
    }

    /// A key set explicitly in the options must reach the provider even when
    /// `OPENAI_API_KEY` is also set.
    #[tokio::test(flavor = "current_thread")]
    async fn explicit_api_key_takes_precedence_over_env_api_key() {
        let _guard = ENV_LOCK.lock().await;
        let saved_openai_key = SavedEnv::capture("OPENAI_API_KEY");
        // SAFETY: `ENV_LOCK` is held, so no other test in this module touches
        // the environment while it is being modified.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "env-openai-key");
        }

        let observed_key = Arc::new(Mutex::new(None));
        let provider_api = Arc::new(TestLanguageModelApi {
            api: "stream-explicit-key-test",
            observed_key: Arc::clone(&observed_key),
        });
        let model = test_model("stream-explicit-key-test", Some(provider_api));

        let explicit_options = SimpleStreamOptions {
            stream: StreamOptions {
                api_key: Some("explicit-key".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };

        let events = stream_simple(model, Context::default(), Some(explicit_options))
            .expect("stream_simple should dispatch");
        let _message = crate::stream::final_message_from_stream(events)
            .await
            .expect("stream result");

        let seen_by_provider = observed_key
            .lock()
            .expect("observed key lock poisoned")
            .clone();
        assert_eq!(seen_by_provider.as_deref(), Some("explicit-key"));

        saved_openai_key.restore();
    }

    /// A model without a language API must be rejected by `stream` with an
    /// unsupported-capability error that names the provider.
    #[test]
    fn stream_reports_model_without_language_api() {
        let model_without_api = test_model("missing-api-provider-test", None);

        let Err(error) = stream(model_without_api, Context::default(), None) else {
            panic!("expected missing provider error");
        };

        let is_expected_error = matches!(
            &error,
            Error::UnsupportedCapability {
                provider,
                capability: "language models",
            } if provider == "openai"
        );
        assert!(is_expected_error);
    }
}