rig/client/
completion.rs

1use crate::agent::AgentBuilder;
2use crate::client::builder::FinalCompletionResponse;
3use crate::client::{AsCompletion, ProviderClient};
4use crate::completion::{
5    CompletionError, CompletionModel, CompletionModelDyn, CompletionRequest, CompletionResponse,
6    GetTokenUsage,
7};
8use crate::extractor::ExtractorBuilder;
9use crate::streaming::StreamingCompletionResponse;
10use crate::wasm_compat::WasmCompatSend;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::future::Future;
14use std::sync::Arc;
15
16/// A provider client with completion capabilities.
17/// Clone is required for conversions between client types.
18pub trait CompletionClient: ProviderClient + Clone {
19    /// The type of CompletionModel used by the client.
20    type CompletionModel: CompletionModel;
21
22    /// Create a completion model with the given name.
23    ///
24    /// # Example with OpenAI
25    /// ```
26    /// use rig::prelude::*;
27    /// use rig::providers::openai::{Client, self};
28    ///
29    /// // Initialize the OpenAI client
30    /// let openai = Client::new("your-open-ai-api-key");
31    ///
32    /// let gpt4 = openai.completion_model(openai::GPT_4);
33    /// ```
34    fn completion_model(&self, model: &str) -> Self::CompletionModel;
35
36    /// Create an agent builder with the given completion model.
37    ///
38    /// # Example with OpenAI
39    /// ```
40    /// use rig::prelude::*;
41    /// use rig::providers::openai::{Client, self};
42    ///
43    /// // Initialize the OpenAI client
44    /// let openai = Client::new("your-open-ai-api-key");
45    ///
46    /// let agent = openai.agent(openai::GPT_4)
47    ///    .preamble("You are comedian AI with a mission to make people laugh.")
48    ///    .temperature(0.0)
49    ///    .build();
50    /// ```
51    fn agent(&self, model: &str) -> AgentBuilder<Self::CompletionModel> {
52        AgentBuilder::new(self.completion_model(model))
53    }
54
55    /// Create an extractor builder with the given completion model.
56    fn extractor<T>(&self, model: &str) -> ExtractorBuilder<Self::CompletionModel, T>
57    where
58        T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
59    {
60        ExtractorBuilder::new(self.completion_model(model))
61    }
62}
63
64/// Wraps a CompletionModel in a dyn-compatible way for AgentBuilder.
65#[derive(Clone)]
66pub struct CompletionModelHandle<'a> {
67    pub inner: Arc<dyn CompletionModelDyn + 'a>,
68}
69
70impl CompletionModel for CompletionModelHandle<'_> {
71    type Response = ();
72    type StreamingResponse = FinalCompletionResponse;
73
74    fn completion(
75        &self,
76        request: CompletionRequest,
77    ) -> impl Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>> + WasmCompatSend
78    {
79        self.inner.completion(request)
80    }
81
82    fn stream(
83        &self,
84        request: CompletionRequest,
85    ) -> impl Future<
86        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
87    > + WasmCompatSend {
88        self.inner.stream(request)
89    }
90}
91
92pub trait CompletionClientDyn: ProviderClient {
93    /// Create a completion model with the given name.
94    fn completion_model<'a>(&self, model: &str) -> Box<dyn CompletionModelDyn + 'a>;
95
96    /// Create an agent builder with the given completion model.
97    fn agent<'a>(&self, model: &str) -> AgentBuilder<CompletionModelHandle<'a>>;
98}
99
100impl<T, M, R> CompletionClientDyn for T
101where
102    T: CompletionClient<CompletionModel = M>,
103    M: CompletionModel<StreamingResponse = R> + 'static,
104    R: Clone + Unpin + GetTokenUsage + 'static,
105{
106    fn completion_model<'a>(&self, model: &str) -> Box<dyn CompletionModelDyn + 'a> {
107        Box::new(self.completion_model(model))
108    }
109
110    fn agent<'a>(&self, model: &str) -> AgentBuilder<CompletionModelHandle<'a>> {
111        AgentBuilder::new(CompletionModelHandle {
112            inner: Arc::new(self.completion_model(model)),
113        })
114    }
115}
116
117impl<T> AsCompletion for T
118where
119    T: CompletionClientDyn + Clone + 'static,
120{
121    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
122        Some(Box::new(self.clone()))
123    }
124}