rig/client/
completion.rs

1use crate::agent::AgentBuilder;
2use crate::client::{AsCompletion, ProviderClient};
3use crate::completion::{
4    CompletionError, CompletionModel, CompletionModelDyn, CompletionRequest, CompletionResponse,
5    GetTokenUsage,
6};
7use crate::extractor::ExtractorBuilder;
8use crate::streaming::StreamingCompletionResponse;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use std::future::Future;
12use std::sync::Arc;
13
14/// A provider client with completion capabilities.
15/// Clone is required for conversions between client types.
16pub trait CompletionClient: ProviderClient + Clone {
17    /// The type of CompletionModel used by the client.
18    type CompletionModel: CompletionModel;
19
20    /// Create a completion model with the given name.
21    ///
22    /// # Example with OpenAI
23    /// ```
24    /// use rig::prelude::*;
25    /// use rig::providers::openai::{Client, self};
26    ///
27    /// // Initialize the OpenAI client
28    /// let openai = Client::new("your-open-ai-api-key");
29    ///
30    /// let gpt4 = openai.completion_model(openai::GPT_4);
31    /// ```
32    fn completion_model(&self, model: &str) -> Self::CompletionModel;
33
34    /// Create an agent builder with the given completion model.
35    ///
36    /// # Example with OpenAI
37    /// ```
38    /// use rig::prelude::*;
39    /// use rig::providers::openai::{Client, self};
40    ///
41    /// // Initialize the OpenAI client
42    /// let openai = Client::new("your-open-ai-api-key");
43    ///
44    /// let agent = openai.agent(openai::GPT_4)
45    ///    .preamble("You are comedian AI with a mission to make people laugh.")
46    ///    .temperature(0.0)
47    ///    .build();
48    /// ```
49    fn agent(&self, model: &str) -> AgentBuilder<Self::CompletionModel> {
50        AgentBuilder::new(self.completion_model(model))
51    }
52
53    /// Create an extractor builder with the given completion model.
54    fn extractor<T>(&self, model: &str) -> ExtractorBuilder<Self::CompletionModel, T>
55    where
56        T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync,
57    {
58        ExtractorBuilder::new(self.completion_model(model))
59    }
60}
61
62/// Wraps a CompletionModel in a dyn-compatible way for AgentBuilder.
63#[derive(Clone)]
64pub struct CompletionModelHandle<'a> {
65    pub inner: Arc<dyn CompletionModelDyn + 'a>,
66}
67
68impl CompletionModel for CompletionModelHandle<'_> {
69    type Response = ();
70    type StreamingResponse = ();
71
72    fn completion(
73        &self,
74        request: CompletionRequest,
75    ) -> impl Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>> + Send
76    {
77        self.inner.completion(request)
78    }
79
80    fn stream(
81        &self,
82        request: CompletionRequest,
83    ) -> impl Future<
84        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
85    > + Send {
86        self.inner.stream(request)
87    }
88}
89
90pub trait CompletionClientDyn: ProviderClient {
91    /// Create a completion model with the given name.
92    fn completion_model<'a>(&self, model: &str) -> Box<dyn CompletionModelDyn + 'a>;
93
94    /// Create an agent builder with the given completion model.
95    fn agent<'a>(&self, model: &str) -> AgentBuilder<CompletionModelHandle<'a>>;
96}
97
98impl<T, M, R> CompletionClientDyn for T
99where
100    T: CompletionClient<CompletionModel = M>,
101    M: CompletionModel<StreamingResponse = R> + 'static,
102    R: Clone + Unpin + GetTokenUsage + 'static,
103{
104    fn completion_model<'a>(&self, model: &str) -> Box<dyn CompletionModelDyn + 'a> {
105        Box::new(self.completion_model(model))
106    }
107
108    fn agent<'a>(&self, model: &str) -> AgentBuilder<CompletionModelHandle<'a>> {
109        AgentBuilder::new(CompletionModelHandle {
110            inner: Arc::new(self.completion_model(model)),
111        })
112    }
113}
114
115impl<T> AsCompletion for T
116where
117    T: CompletionClientDyn + Clone + 'static,
118{
119    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
120        Some(Box::new(self.clone()))
121    }
122}