// forge-guardrails 0.1.0
//
// Foundation types for an LLM-agent workflow framework.
// Documentation
use forge_guardrails::{
    AnthropicClient, AnyLlmProxyClient, AnyLlmRuntimeClient, ApiFormat, BackendError, ChunkStream,
    ContextDiscoveryError, LLMCallInfo, LLMClient, LLMRequestOptions, LLMResponse,
    LLMResponseEnvelope, LlamafileClient, OllamaClient, SamplingParams, StreamError, TokenUsage,
    ToolSpec,
};
use serde_json::Value;

/// Backend configuration that hands out one [`RoutedClient`] per requested model via
/// [`ClientFactory::client_for_model`].
pub(crate) enum ClientFactory {
    /// Wraps an existing [`AnyLlmRuntimeClient`]; per-model clients come from its `for_model`.
    Runtime(AnyLlmRuntimeClient),
    /// OpenAI-style chat-completions endpoint reached through [`AnyLlmProxyClient`].
    DirectOpenAi {
        /// Endpoint base URL handed to `with_base_url`.
        base_url: String,
        /// API key; only applied to the routed client when `Some`.
        api_key: Option<String>,
        /// Cloned into every routed client this variant produces.
        http_client: reqwest::Client,
        /// Handed to `AnyLlmProxyClient::with_context_length`.
        context_tokens: i64,
    },
    /// Endpoint reached through [`AnthropicClient`].
    DirectAnthropic {
        /// Endpoint base URL handed to `with_base_url`.
        base_url: String,
        /// Passed to `AnthropicClient::new` as-is (may be `None`).
        api_key: Option<String>,
        /// Cloned into every routed client this variant produces.
        http_client: reqwest::Client,
        /// Stored on the routed client and returned by `get_context_length` without
        /// querying the backend.
        context_tokens: i64,
    },
    /// Already-running llamafile server reached through [`LlamafileClient`].
    DirectLlamafile {
        /// Server base URL handed to `with_base_url`.
        base_url: String,
        /// Passed to `LlamafileClient::with_mode`; presumably selects the server API flavour —
        /// TODO confirm the accepted values (consider an enum instead of a string).
        mode: String,
        /// Cloned into every routed client this variant produces.
        http_client: reqwest::Client,
        /// Stored on the routed client and returned by `get_context_length` without
        /// querying the backend.
        context_tokens: i64,
    },
    /// Managed llamafile pinned to a single GGUF file: routed clients are built from
    /// `gguf_path`, so the model name requested from the factory is ignored.
    ManagedLlamafile {
        /// Used in place of the requested model name when constructing the client.
        gguf_path: String,
        /// Server base URL handed to `with_base_url`.
        base_url: String,
        /// Passed to `LlamafileClient::with_mode` (see `DirectLlamafile::mode`).
        mode: String,
        /// Cloned into every routed client this variant produces.
        http_client: reqwest::Client,
    },
    /// Managed Ollama pinned to `model`; the model name requested from the factory is ignored.
    ManagedOllama {
        /// Model every routed client targets.
        model: String,
        /// Cloned into every routed client this variant produces.
        http_client: reqwest::Client,
        /// Applied to each routed client via `OllamaClient::set_num_ctx`.
        context_tokens: i64,
    },
}

/// Per-model client produced by [`ClientFactory::client_for_model`]; its [`LLMClient`] impl
/// forwards calls to the wrapped backend client.
pub(crate) enum RoutedClient {
    Runtime(AnyLlmRuntimeClient),
    DirectOpenAi(AnyLlmProxyClient),
    /// The `i64` is the configured context size, returned by `get_context_length` instead of
    /// querying the backend.
    DirectAnthropic(AnthropicClient, i64),
    /// The `i64` is the configured context size, returned by `get_context_length` instead of
    /// querying the backend.
    DirectLlamafile(LlamafileClient, i64),
    ManagedLlamafile(LlamafileClient),
    ManagedOllama(OllamaClient),
}

impl ClientFactory {
    /// Builds a fresh [`RoutedClient`] for `model` on this factory's backend.
    ///
    /// The stored `reqwest::Client` is cloned into each routed client, so all clients from one
    /// factory reuse the same HTTP transport.
    ///
    /// The managed variants are pinned to a fixed model and do not use `model`:
    /// `ManagedLlamafile` builds its client from `gguf_path`, and `ManagedOllama` builds its
    /// client from its own configured model.
    pub(crate) fn client_for_model(&self, model: String) -> RoutedClient {
        match self {
            Self::Runtime(runtime) => RoutedClient::Runtime(runtime.for_model(model)),
            Self::DirectOpenAi {
                base_url,
                api_key,
                http_client,
                context_tokens,
            } => {
                let proxy = AnyLlmProxyClient::new(model)
                    .with_base_url(base_url)
                    .with_http_client(http_client.clone())
                    .with_context_length(*context_tokens);
                // Only attach credentials when a key was configured.
                let proxy = match api_key {
                    Some(key) => proxy.with_api_key(key.clone()),
                    None => proxy,
                };
                RoutedClient::DirectOpenAi(proxy)
            }
            Self::DirectAnthropic {
                base_url,
                api_key,
                http_client,
                context_tokens,
            } => {
                let anthropic = AnthropicClient::new(model, api_key.clone())
                    .with_base_url(base_url)
                    .with_http_client(http_client.clone());
                RoutedClient::DirectAnthropic(anthropic, *context_tokens)
            }
            Self::DirectLlamafile {
                base_url,
                mode,
                http_client,
                context_tokens,
            } => {
                let llamafile = LlamafileClient::new(model)
                    .with_base_url(base_url)
                    .with_mode(mode)
                    .with_http_client(http_client.clone());
                RoutedClient::DirectLlamafile(llamafile, *context_tokens)
            }
            Self::ManagedLlamafile {
                gguf_path,
                base_url,
                mode,
                http_client,
            } => {
                // Pinned to the managed GGUF file; the requested model name is not used.
                let llamafile = LlamafileClient::new(gguf_path)
                    .with_base_url(base_url)
                    .with_mode(mode)
                    .with_http_client(http_client.clone());
                RoutedClient::ManagedLlamafile(llamafile)
            }
            Self::ManagedOllama {
                model: pinned_model,
                http_client,
                context_tokens,
            } => {
                // Pinned to the factory's model; the requested model name is not used.
                let ollama =
                    OllamaClient::new(pinned_model.clone()).with_http_client(http_client.clone());
                ollama.set_num_ctx(Some(*context_tokens));
                RoutedClient::ManagedOllama(ollama)
            }
        }
    }
}

/// Evaluates `$call` against the backend client wrapped by a [`RoutedClient`].
///
/// Expands to an exhaustive `match` that binds the wrapped client to `$client` in every arm.
/// Adding a variant to `RoutedClient` is therefore a compile error here, not a silently missed
/// arm in one of many hand-copied matches. The configured context size carried by
/// `DirectAnthropic` / `DirectLlamafile` is not bound; methods that need it match on the enum
/// directly.
macro_rules! dispatch_routed {
    ($routed:expr, |$client:ident| $call:expr) => {
        match $routed {
            RoutedClient::Runtime($client) => $call,
            RoutedClient::DirectOpenAi($client) => $call,
            RoutedClient::DirectAnthropic($client, _) => $call,
            RoutedClient::DirectLlamafile($client, _) => $call,
            RoutedClient::ManagedLlamafile($client) => $call,
            RoutedClient::ManagedOllama($client) => $call,
        }
    };
}

/// Forwards [`LLMClient`] calls to the backend client a [`RoutedClient`] wraps.
///
/// `last_usage` and `last_call_info` report whatever the wrapped client recorded. The only
/// method that does not forward uniformly is `get_context_length` (see its comment).
impl LLMClient for RoutedClient {
    fn api_format(&self) -> ApiFormat {
        dispatch_routed!(self, |client| client.api_format())
    }

    fn last_usage(&self) -> Option<TokenUsage> {
        dispatch_routed!(self, |client| client.last_usage())
    }

    fn last_call_info(&self) -> Option<LLMCallInfo> {
        dispatch_routed!(self, |client| client.last_call_info())
    }

    async fn send(
        &self,
        messages: Vec<Value>,
        tools: Option<Vec<ToolSpec>>,
        sampling: Option<SamplingParams>,
    ) -> Result<LLMResponse, BackendError> {
        dispatch_routed!(self, |client| client.send(messages, tools, sampling).await)
    }

    async fn send_with_options(
        &self,
        messages: Vec<Value>,
        tools: Option<Vec<ToolSpec>>,
        options: LLMRequestOptions,
    ) -> Result<LLMResponse, BackendError> {
        dispatch_routed!(self, |client| {
            client.send_with_options(messages, tools, options).await
        })
    }

    async fn send_envelope_with_options(
        &self,
        messages: Vec<Value>,
        tools: Option<Vec<ToolSpec>>,
        options: LLMRequestOptions,
    ) -> Result<LLMResponseEnvelope, BackendError> {
        dispatch_routed!(self, |client| {
            client
                .send_envelope_with_options(messages, tools, options)
                .await
        })
    }

    async fn send_stream(
        &self,
        messages: Vec<Value>,
        tools: Option<Vec<ToolSpec>>,
        sampling: Option<SamplingParams>,
    ) -> Result<ChunkStream, StreamError> {
        dispatch_routed!(self, |client| client.send_stream(messages, tools, sampling).await)
    }

    async fn send_stream_with_options(
        &self,
        messages: Vec<Value>,
        tools: Option<Vec<ToolSpec>>,
        options: LLMRequestOptions,
    ) -> Result<ChunkStream, StreamError> {
        dispatch_routed!(self, |client| {
            client
                .send_stream_with_options(messages, tools, options)
                .await
        })
    }

    async fn get_context_length(&self) -> Result<Option<i64>, ContextDiscoveryError> {
        match self {
            // These variants report the context size configured on the factory and never
            // contact the backend for it.
            Self::DirectAnthropic(_, context_tokens)
            | Self::DirectLlamafile(_, context_tokens) => Ok(Some(*context_tokens)),
            Self::Runtime(client) => client.get_context_length().await,
            Self::DirectOpenAi(client) => client.get_context_length().await,
            Self::ManagedLlamafile(client) => client.get_context_length().await,
            Self::ManagedOllama(client) => client.get_context_length().await,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Two clients routed from one `DirectOpenAi` factory hit the same upstream, but only the
    /// client that actually sent a request reports token usage.
    #[tokio::test]
    async fn direct_openai_clients_share_transport_but_keep_usage_isolated() {
        // Canned chat completion reporting 2 prompt + 3 completion tokens.
        let completion = json!({
            "id": "chatcmpl-test",
            "object": "chat.completion",
            "created": 0,
            "model": "request-model",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "ok"},
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5}
        });

        let mut server = mockito::Server::new_async().await;
        // Keep the mock bound for the whole test: mockito removes a mock once it is dropped,
        // so `let _ = ...` would unregister it immediately.
        let _completions = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(completion.to_string())
            .create_async()
            .await;

        let factory = ClientFactory::DirectOpenAi {
            base_url: server.url(),
            api_key: None,
            http_client: reqwest::Client::new(),
            context_tokens: 8192,
        };
        let requesting = factory.client_for_model("first-model".to_string());
        let idle = factory.client_for_model("second-model".to_string());

        let messages = vec![json!({"role": "user", "content": "hello"})];
        requesting
            .send_with_options(messages, None, LLMRequestOptions::default())
            .await
            .expect("request");

        assert_eq!(
            requesting.last_usage(),
            Some(TokenUsage::new(2, 3, 5)),
            "requesting client records usage"
        );
        assert_eq!(idle.last_usage(), None, "routed clients keep usage isolated");
    }
}