rig_vertexai/
completion.rs1use super::Client;
4use crate::types::{
5 completion_request::VertexCompletionRequest, completion_response::VertexGenerateContentOutput,
6};
7use rig::completion::{
8 CompletionError, CompletionModel as CompletionModelTrait, CompletionRequest,
9 CompletionResponse, GetTokenUsage,
10};
11use rig::streaming::StreamingCompletionResponse;
12use serde::{Deserialize, Serialize};
13
// Gemini model identifiers. Any of these can be passed as the `model` of a
// `CompletionModel`; the identifier becomes the final segment of the
// `projects/{project}/locations/{location}/publishers/google/models/{model}` path.

/// Model identifier `gemini-1.5-pro`.
pub const GEMINI_1_5_PRO: &str = "gemini-1.5-pro";
/// Model identifier `gemini-1.5-flash`.
pub const GEMINI_1_5_FLASH: &str = "gemini-1.5-flash";
/// Model identifier `gemini-1.5-pro-latest`.
pub const GEMINI_1_5_PRO_LATEST: &str = "gemini-1.5-pro-latest";
/// Model identifier `gemini-1.5-flash-latest`.
pub const GEMINI_1_5_FLASH_LATEST: &str = "gemini-1.5-flash-latest";
/// Model identifier `gemini-2.0-flash-exp`.
pub const GEMINI_2_0_FLASH_EXP: &str = "gemini-2.0-flash-exp";
/// Model identifier `gemini-2.5-flash-lite`.
pub const GEMINI_2_5_FLASH_LITE: &str = "gemini-2.5-flash-lite";
/// Model identifier `gemini-2.5-flash`.
pub const GEMINI_2_5_FLASH: &str = "gemini-2.5-flash";
/// Model identifier `gemini-2.5-pro`.
pub const GEMINI_2_5_PRO: &str = "gemini-2.5-pro";
30
31#[derive(Clone)]
32pub struct CompletionModel {
33 pub(crate) client: crate::client::Client,
34 pub model: String,
35}
36
37#[derive(Clone, Serialize, Deserialize)]
38pub struct PlaceholderStreamingResponse;
39
40impl GetTokenUsage for PlaceholderStreamingResponse {
41 fn token_usage(&self) -> Option<rig::completion::Usage> {
42 None
43 }
44}
45
46impl CompletionModel {
47 pub fn new(client: Client, model: impl Into<String>) -> Self {
48 Self {
49 client,
50 model: model.into(),
51 }
52 }
53
54 pub fn with_model(client: Client, model: &str) -> Self {
55 Self {
56 client,
57 model: model.into(),
58 }
59 }
60
61 fn model_path(&self) -> Result<String, CompletionError> {
62 let project = self.client.project();
63 let location = self.client.location();
64 Ok(format!(
65 "projects/{project}/locations/{location}/publishers/google/models/{}",
66 self.model
67 ))
68 }
69}
70
71impl CompletionModelTrait for CompletionModel {
72 type Response = VertexGenerateContentOutput;
73 type StreamingResponse = PlaceholderStreamingResponse;
74
75 type Client = Client;
76
77 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
78 Self::new(client.clone(), model.into())
79 }
80
81 async fn completion(
82 &self,
83 request: CompletionRequest,
84 ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
85 tracing::debug!(
86 target: "rig::vertexai",
87 "Vertex AI completion request: {request:?}"
88 );
89
90 let vertex_request = VertexCompletionRequest(request);
91
92 let contents = vertex_request.contents()?;
93 let generation_config = vertex_request.generation_config();
94 let system_instruction = vertex_request.system_instruction();
95 let tools = vertex_request.tools();
96 let tool_config = vertex_request.tool_config();
97 let model_path = self.model_path()?;
98
99 let mut request_builder = self
100 .client
101 .get_inner()
102 .await
103 .map_err(|error| CompletionError::ProviderError(error.to_string()))?
104 .generate_content()
105 .set_model(&model_path)
106 .set_contents(contents);
107
108 if let Some(config) = generation_config {
109 request_builder = request_builder.set_generation_config(config);
110 }
111
112 if let Some(system_instruction) = system_instruction {
113 request_builder = request_builder.set_system_instruction(system_instruction);
114 }
115
116 if let Some(tools) = tools {
117 request_builder = request_builder.set_tools([tools]);
118 }
119
120 if let Some(tool_config) = tool_config {
121 request_builder = request_builder.set_tool_config(tool_config);
122 }
123
124 let response = request_builder
125 .send()
126 .await
127 .map_err(|e| CompletionError::ProviderError(format!("Vertex AI API error: {e}")))?;
128
129 tracing::debug!(
130 target: "rig::vertexai",
131 "Vertex AI completion response: {response:?}"
132 );
133
134 let vertex_output = VertexGenerateContentOutput(response);
135 let completion_response = vertex_output.try_into()?;
136
137 Ok(completion_response)
138 }
139
140 async fn stream(
141 &self,
142 _request: CompletionRequest,
143 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
144 Err(CompletionError::ProviderError(
145 "Streaming is not supported for Vertex AI in this integration".to_string(),
146 ))
147 }
148}