//! firebase_rs_sdk/ai/models/generative_model.rs

use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::ai::api::AiService;
6use crate::ai::backend::BackendType;
7use crate::ai::error::AiResult;
8use crate::ai::requests::{PreparedRequest, RequestOptions, Task};
9
/// Handle for a specific generative model exposed by an [`AiService`].
///
/// Created via [`GenerativeModel::new`], which normalizes the caller-supplied
/// model name into the backend's fully-qualified resource form. Cloning is
/// cheap: the service is shared behind an `Arc`.
#[derive(Clone, Debug)]
pub struct GenerativeModel {
    // Shared service used to obtain the request factory and backend type.
    service: Arc<AiService>,
    // Backend-normalized model resource name (see `normalize_model_name`).
    model: String,
    // Fallback options used when a call passes `None` for its own options.
    default_request_options: Option<RequestOptions>,
}
19
20impl GenerativeModel {
21 pub fn new(
26 service: Arc<AiService>,
27 model_name: impl Into<String>,
28 request_options: Option<RequestOptions>,
29 ) -> AiResult<Self> {
30 let backend_type = service.backend_type();
31 let model = normalize_model_name(model_name.into(), backend_type);
32 Ok(Self {
33 service,
34 model,
35 default_request_options: request_options,
36 })
37 }
38
39 pub fn model(&self) -> &str {
41 &self.model
42 }
43
44 pub async fn prepare_generate_content_request(
46 &self,
47 body: Value,
48 request_options: Option<RequestOptions>,
49 ) -> AiResult<PreparedRequest> {
50 let factory = self.service.request_factory().await?;
51 let effective_options = request_options.or_else(|| self.default_request_options.clone());
52 factory.construct_request(
53 &self.model,
54 Task::GenerateContent,
55 false,
56 body,
57 effective_options,
58 )
59 }
60}
61
62fn normalize_model_name(model: String, backend_type: BackendType) -> String {
63 match backend_type {
64 BackendType::GoogleAi => normalize_google_ai_model(model),
65 BackendType::VertexAi => normalize_vertex_ai_model(model),
66 }
67}
68
/// Ensures a Google AI (Developer API) model name carries the `models/`
/// resource prefix, prepending it when absent; already-prefixed names
/// pass through unchanged.
fn normalize_google_ai_model(model: String) -> String {
    if model.starts_with("models/") {
        return model;
    }
    format!("models/{model}")
}
76
/// Qualifies a Vertex AI model name as `publishers/google/models/<name>`.
///
/// Bare names get the full prefix; `models/...` names get only the
/// publisher segment prepended; any other slash-containing path is treated
/// as already fully qualified and returned untouched.
fn normalize_vertex_ai_model(model: String) -> String {
    if !model.contains('/') {
        return format!("publishers/google/models/{model}");
    }
    if model.starts_with("models/") {
        return format!("publishers/google/{model}");
    }
    model
}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::public_types::AiOptions;
    use crate::app::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;
    use std::time::Duration;

    // Produces settings with a process-unique app name so each test gets its
    // own Firebase app instance (apps are keyed by name).
    fn unique_settings() -> FirebaseAppSettings {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        FirebaseAppSettings {
            name: Some(format!(
                "gen-model-{}",
                COUNTER.fetch_add(1, Ordering::SeqCst)
            )),
            ..Default::default()
        }
    }

    // Initializes a fresh app and returns its AiService.
    // `backend == None` uses the SDK default; `Some(...)` pins an explicit
    // backend via AiOptions.
    async fn init_service(options: FirebaseOptions, backend: Option<Backend>) -> Arc<AiService> {
        let app = initialize_app(options, Some(unique_settings()))
            .await
            .unwrap();
        match backend {
            Some(backend) => crate::ai::get_ai(
                Some(app),
                Some(AiOptions {
                    backend: Some(backend),
                    use_limited_use_app_check_tokens: None,
                }),
            )
            .await
            .unwrap(),
            None => crate::ai::get_ai_service(Some(app)).await.unwrap(),
        }
    }

    // Google AI backend: bare names gain a `models/` prefix; already-prefixed
    // names are left untouched (normalization is idempotent).
    #[tokio::test(flavor = "current_thread")]
    async fn normalizes_google_model_name() {
        let service = init_service(
            FirebaseOptions {
                api_key: Some("api".into()),
                project_id: Some("project".into()),
                app_id: Some("app".into()),
                ..Default::default()
            },
            None,
        )
        .await;
        let model = GenerativeModel::new(service.clone(), "gemini-pro", None).unwrap();
        assert_eq!(model.model(), "models/gemini-pro");

        let already_prefixed = GenerativeModel::new(service, "models/gemini-pro", None).unwrap();
        assert_eq!(already_prefixed.model(), "models/gemini-pro");
    }

    // Vertex AI backend: bare names become publishers/google/models/<name>,
    // and the prepared request URL/timeout reflect the default RequestOptions
    // supplied at model construction.
    #[tokio::test(flavor = "current_thread")]
    async fn normalizes_vertex_model_name_and_prepares_request() {
        let service = init_service(
            FirebaseOptions {
                api_key: Some("api".into()),
                project_id: Some("project".into()),
                app_id: Some("app".into()),
                ..Default::default()
            },
            Some(Backend::vertex_ai("us-central1")),
        )
        .await;
        let model = GenerativeModel::new(
            service,
            "gemini-pro",
            Some(RequestOptions {
                timeout: Some(Duration::from_secs(5)),
                base_url: Some("https://example.com".into()),
            }),
        )
        .unwrap();

        assert_eq!(model.model(), "publishers/google/models/gemini-pro");

        // No per-call options: the defaults above must flow into the request.
        let prepared = model
            .prepare_generate_content_request(json!({"contents": []}), None)
            .await
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://example.com/v1beta/projects/project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
        );
        assert_eq!(prepared.timeout, Duration::from_secs(5));
    }
}
182}