// rig_vertexai/completion.rs
1//! All supported models: <https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini>
2
3use super::Client;
4use crate::types::{
5    completion_request::VertexCompletionRequest, completion_response::VertexGenerateContentOutput,
6};
7use rig::completion::{
8    CompletionError, CompletionModel as CompletionModelTrait, CompletionRequest,
9    CompletionResponse, GetTokenUsage,
10};
11use rig::streaming::StreamingCompletionResponse;
12use serde::{Deserialize, Serialize};
13
// Model-name constants accepted by the Vertex AI generative API.
// Pass one of these (or any other valid model id string) to
// `CompletionModel::new` / `with_model`.

/// `gemini-1.5-pro`
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// `gemini-1.5-flash`
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// `gemini-1.5-pro-latest`
pub const GEMINI_1_5_PRO_LATEST: &str = "gemini-1.5-pro-latest";
/// `gemini-1.5-flash-latest`
pub const GEMINI_1_5_FLASH_LATEST: &str = "gemini-1.5-flash-latest";
/// `gemini-2.0-flash-exp`
pub const GEMINI_2_0_FLASH_EXP: &str = "gemini-2.0-flash-exp";
/// `gemini-2.5-flash-lite`
pub const GEMINI_2_5_FLASH_LITE: &str = "gemini-2.5-flash-lite";
/// `gemini-2.5-flash`
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// `gemini-2.5-pro`
pub const GEMINI_2_5_PRO: &str = "gemini-2.5-pro";
30
/// A Vertex AI completion model: a client handle paired with the model name
/// (e.g. [`GEMINI_2_5_FLASH`]) to send requests to.
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to reach the Vertex AI API; also supplies project/location
    // for building the model resource path.
    pub(crate) client: crate::client::Client,
    // Model identifier, e.g. "gemini-2.5-pro".
    pub model: String,
}
36
/// Placeholder streaming-response type: streaming is not supported by this
/// integration (`stream` always returns an error), but the
/// `CompletionModel` trait requires a concrete associated type.
#[derive(Clone, Serialize, Deserialize)]
pub struct PlaceholderStreamingResponse;
39
impl GetTokenUsage for PlaceholderStreamingResponse {
    /// Streaming is unsupported here, so there is never token usage to report.
    fn token_usage(&self) -> Option<rig::completion::Usage> {
        None
    }
}
45
46impl CompletionModel {
47    pub fn new(client: Client, model: impl Into<String>) -> Self {
48        Self {
49            client,
50            model: model.into(),
51        }
52    }
53
54    pub fn with_model(client: Client, model: &str) -> Self {
55        Self {
56            client,
57            model: model.into(),
58        }
59    }
60
61    fn model_path(&self) -> Result<String, CompletionError> {
62        let project = self.client.project();
63        let location = self.client.location();
64        Ok(format!(
65            "projects/{project}/locations/{location}/publishers/google/models/{}",
66            self.model
67        ))
68    }
69}
70
impl CompletionModelTrait for CompletionModel {
    type Response = VertexGenerateContentOutput;
    // Streaming is not implemented; this placeholder satisfies the associated
    // type while `stream` (below) always returns an error.
    type StreamingResponse = PlaceholderStreamingResponse;

    type Client = Client;

    /// Trait constructor: clones the shared client and wraps it with the
    /// given model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a non-streaming completion request to Vertex AI.
    ///
    /// Converts the generic `CompletionRequest` into Vertex-specific parts,
    /// builds and sends a `generate_content` call, then converts the raw
    /// response back into a `CompletionResponse`.
    ///
    /// # Errors
    /// Returns `CompletionError::ProviderError` if the underlying client
    /// cannot be obtained or the API call fails; conversion errors surface
    /// from the request/response adapters.
    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        // NOTE(review): this logs the full request (may include prompt
        // contents) at debug level — confirm that is acceptable.
        tracing::debug!(
            target: "rig::vertexai",
            "Vertex AI completion request: {request:?}"
        );

        // Newtype adapter that knows how to split the generic request into
        // Vertex AI request components.
        let vertex_request = VertexCompletionRequest(request);

        // Extract each optional component; only `contents` is required.
        let contents = vertex_request.contents()?;
        let generation_config = vertex_request.generation_config();
        let system_instruction = vertex_request.system_instruction();
        let tools = vertex_request.tools();
        let tool_config = vertex_request.tool_config();
        let model_path = self.model_path()?;

        // `get_inner` is async and fallible — presumably it lazily
        // initializes/authenticates the underlying SDK client; verify.
        let mut request_builder = self
            .client
            .get_inner()
            .await
            .map_err(|error| CompletionError::ProviderError(error.to_string()))?
            .generate_content()
            .set_model(&model_path)
            .set_contents(contents);

        // Attach the optional components only when present, so the API's own
        // defaults apply otherwise.
        if let Some(config) = generation_config {
            request_builder = request_builder.set_generation_config(config);
        }

        if let Some(system_instruction) = system_instruction {
            request_builder = request_builder.set_system_instruction(system_instruction);
        }

        if let Some(tools) = tools {
            // `set_tools` takes an iterable; the adapter yields a single
            // aggregated `tools` value, hence the one-element array.
            request_builder = request_builder.set_tools([tools]);
        }

        if let Some(tool_config) = tool_config {
            request_builder = request_builder.set_tool_config(tool_config);
        }

        let response = request_builder
            .send()
            .await
            .map_err(|e| CompletionError::ProviderError(format!("Vertex AI API error: {e}")))?;

        tracing::debug!(
            target: "rig::vertexai",
            "Vertex AI completion response: {response:?}"
        );

        // Wrap the raw SDK output so the crate's `TryInto` adapter can map it
        // to the generic `CompletionResponse`.
        let vertex_output = VertexGenerateContentOutput(response);
        let completion_response = vertex_output.try_into()?;

        Ok(completion_response)
    }

    /// Streaming is not supported by this integration; always errors.
    async fn stream(
        &self,
        _request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        Err(CompletionError::ProviderError(
            "Streaming is not supported for Vertex AI in this integration".to_string(),
        ))
    }
}