firebase_rs_sdk/ai/models/generative_model.rs

use std::sync::Arc;

use serde_json::Value;

use crate::ai::api::AiService;
use crate::ai::backend::BackendType;
use crate::ai::error::AiResult;
use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};

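/// Handle to a single generative model, bundling the resolved API settings,
/// the backend-qualified model name, and any default request options.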
#[derive(Clone, Debug)]
pub struct GenerativeModel {
    api_settings: ApiSettings,
    model: String,
    default_request_options: Option<RequestOptions>,
}

impl GenerativeModel {
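    /// Creates a model handle from an `AiService`, normalizing `model_name`
    /// into the resource path expected by the service's backend.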
    pub fn new(
        service: Arc<AiService>,
        model_name: impl Into<String>,
        request_options: Option<RequestOptions>,
    ) -> AiResult<Self> {
        let api_settings = service.api_settings()?;
        let backend_type = service.backend_type();
        let model = normalize_model_name(model_name.into(), backend_type);
        Ok(Self {
            api_settings,
            model,
            default_request_options: request_options,
        })
    }

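    /// Returns the normalized model name, e.g. `models/gemini-pro`.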
    pub fn model(&self) -> &str {
        &self.model
    }

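    /// Builds a `generateContent` request for this model. Options passed here
    /// take precedence over the defaults supplied at construction time.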
    pub fn prepare_generate_content_request(
        &self,
        body: Value,
        request_options: Option<RequestOptions>,
    ) -> AiResult<PreparedRequest> {
        let factory = RequestFactory::new(self.api_settings.clone());
        let effective_options = request_options.or_else(|| self.default_request_options.clone());
        factory.construct_request(
            &self.model,
            Task::GenerateContent,
            false,
            body,
            effective_options,
        )
    }
}

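/// Expands a user-supplied model name into the resource path used by the
/// chosen backend.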
fn normalize_model_name(model: String, backend_type: BackendType) -> String {
    match backend_type {
        BackendType::GoogleAi => normalize_google_ai_model(model),
        BackendType::VertexAi => normalize_vertex_ai_model(model),
    }
}

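/// Google AI model names are prefixed with `models/` unless already present.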
fn normalize_google_ai_model(model: String) -> String {
    if model.starts_with("models/") {
        model
    } else {
        format!("models/{model}")
    }
}

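/// Vertex AI models live under `publishers/google/`: bare names and
/// `models/`-prefixed names are expanded, while already-qualified paths are
/// left untouched.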
fn normalize_vertex_ai_model(model: String) -> String {
    if model.contains('/') {
        if model.starts_with("models/") {
            format!("publishers/google/{model}")
        } else {
            model
        }
    } else {
        format!("publishers/google/models/{model}")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::public_types::AiOptions;
    use crate::app::api::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;
    use std::time::Duration;

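    // Produces a unique app name per call so tests can initialize separate
    // Firebase apps without colliding.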
    fn unique_settings() -> FirebaseAppSettings {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        FirebaseAppSettings {
            name: Some(format!(
                "gen-model-{}",
                COUNTER.fetch_add(1, Ordering::SeqCst)
            )),
            ..Default::default()
        }
    }

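    // Initializes a fresh app and returns an `AiService` for the requested
    // backend, falling back to the default service when `backend` is `None`.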
    fn init_service(options: FirebaseOptions, backend: Option<Backend>) -> Arc<AiService> {
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        match backend {
            Some(backend) => crate::ai::get_ai(
                Some(app),
                Some(AiOptions {
                    backend: Some(backend),
                    use_limited_use_app_check_tokens: None,
                }),
            )
            .unwrap(),
            None => crate::ai::get_ai_service(Some(app)).unwrap(),
        }
    }

    #[test]
    fn normalizes_google_model_name() {
        let service = init_service(
            FirebaseOptions {
                api_key: Some("api".into()),
                project_id: Some("project".into()),
                app_id: Some("app".into()),
                ..Default::default()
            },
            None,
        );
        let model = GenerativeModel::new(service.clone(), "gemini-pro", None).unwrap();
        assert_eq!(model.model(), "models/gemini-pro");

        let already_prefixed = GenerativeModel::new(service, "models/gemini-pro", None).unwrap();
        assert_eq!(already_prefixed.model(), "models/gemini-pro");
    }

    #[test]
    fn normalizes_vertex_model_name_and_prepares_request() {
        let service = init_service(
            FirebaseOptions {
                api_key: Some("api".into()),
                project_id: Some("project".into()),
                app_id: Some("app".into()),
                ..Default::default()
            },
            Some(Backend::vertex_ai("us-central1")),
        );
        let model = GenerativeModel::new(
            service,
            "gemini-pro",
            Some(RequestOptions {
                timeout: Some(Duration::from_secs(5)),
                base_url: Some("https://example.com".into()),
            }),
        )
        .unwrap();

        assert_eq!(model.model(), "publishers/google/models/gemini-pro");

        let prepared = model
            .prepare_generate_content_request(json!({"contents": []}), None)
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://example.com/v1beta/projects/project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
        );
        assert_eq!(prepared.timeout, Duration::from_secs(5));
    }
}
177}