// firebase_rs_sdk/ai/models/generative_model.rs

1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::ai::api::AiService;
6use crate::ai::backend::BackendType;
7use crate::ai::error::AiResult;
8use crate::ai::requests::{ApiSettings, PreparedRequest, RequestFactory, RequestOptions, Task};
9
10/// Port of the Firebase JS SDK `GenerativeModel` class.
11///
12/// Reference: `packages/ai/src/models/generative-model.ts`.
13#[derive(Clone, Debug)]
14pub struct GenerativeModel {
15    api_settings: ApiSettings,
16    model: String,
17    default_request_options: Option<RequestOptions>,
18}
19
20impl GenerativeModel {
21    /// Creates a new generative model bound to the provided `AiService`.
22    ///
23    /// This mirrors the TypeScript constructor, normalising the model name according to the selected
24    /// backend and capturing the service API settings for later requests.
25    pub fn new(
26        service: Arc<AiService>,
27        model_name: impl Into<String>,
28        request_options: Option<RequestOptions>,
29    ) -> AiResult<Self> {
30        let api_settings = service.api_settings()?;
31        let backend_type = service.backend_type();
32        let model = normalize_model_name(model_name.into(), backend_type);
33        Ok(Self {
34            api_settings,
35            model,
36            default_request_options: request_options,
37        })
38    }
39
40    /// Returns the fully qualified model resource identifier.
41    pub fn model(&self) -> &str {
42        &self.model
43    }
44
45    /// Prepares a `generateContent` request using the stored API settings.
46    pub fn prepare_generate_content_request(
47        &self,
48        body: Value,
49        request_options: Option<RequestOptions>,
50    ) -> AiResult<PreparedRequest> {
51        let factory = RequestFactory::new(self.api_settings.clone());
52        let effective_options = request_options.or_else(|| self.default_request_options.clone());
53        factory.construct_request(
54            &self.model,
55            Task::GenerateContent,
56            false,
57            body,
58            effective_options,
59        )
60    }
61}
62
63fn normalize_model_name(model: String, backend_type: BackendType) -> String {
64    match backend_type {
65        BackendType::GoogleAi => normalize_google_ai_model(model),
66        BackendType::VertexAi => normalize_vertex_ai_model(model),
67    }
68}
69
/// Qualifies a Google AI model name with the `models/` collection prefix.
///
/// Names that already start with `models/` are returned untouched, so normalising twice is
/// harmless. Any other input, including the empty string, gets the prefix prepended.
fn normalize_google_ai_model(model: String) -> String {
    if !model.starts_with("models/") {
        return format!("models/{model}");
    }
    model
}
77
/// Qualifies a Vertex AI model name with the Google publisher path where needed.
///
/// - A bare name (no `/`) is treated as a Google-published base model:
///   `gemini-pro` becomes `publishers/google/models/gemini-pro`.
/// - A `models/...` name only lacks the publisher segment, which is prepended.
/// - Any other path (e.g. a tuned model or a full resource name) is trusted as-is.
fn normalize_vertex_ai_model(model: String) -> String {
    if !model.contains('/') {
        format!("publishers/google/models/{model}")
    } else if model.starts_with("models/") {
        format!("publishers/google/{model}")
    } else {
        model
    }
}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::public_types::AiOptions;
    use crate::app::api::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Monotonic counter giving every test app its own name.
    static APP_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// App settings with a never-repeating name, so tests never collide on app registration.
    fn unique_settings() -> FirebaseAppSettings {
        FirebaseAppSettings {
            name: Some(format!(
                "gen-model-{}",
                APP_COUNTER.fetch_add(1, Ordering::SeqCst)
            )),
            ..Default::default()
        }
    }

    /// Minimal Firebase options shared by the tests below.
    fn base_options() -> FirebaseOptions {
        FirebaseOptions {
            api_key: Some("api".into()),
            project_id: Some("project".into()),
            app_id: Some("app".into()),
            ..Default::default()
        }
    }

    /// Initialises a fresh app and resolves its AI service, pinning `backend` when given.
    fn init_service(options: FirebaseOptions, backend: Option<Backend>) -> Arc<AiService> {
        let app = initialize_app(options, Some(unique_settings())).unwrap();
        if let Some(backend) = backend {
            let ai_options = AiOptions {
                backend: Some(backend),
                use_limited_use_app_check_tokens: None,
            };
            crate::ai::get_ai(Some(app), Some(ai_options)).unwrap()
        } else {
            crate::ai::get_ai_service(Some(app)).unwrap()
        }
    }

    #[test]
    fn normalizes_google_model_name() {
        let service = init_service(base_options(), None);

        let bare = GenerativeModel::new(service.clone(), "gemini-pro", None).unwrap();
        assert_eq!(bare.model(), "models/gemini-pro");

        let prefixed = GenerativeModel::new(service, "models/gemini-pro", None).unwrap();
        assert_eq!(prefixed.model(), "models/gemini-pro");
    }

    #[test]
    fn normalizes_vertex_model_name_and_prepares_request() {
        let service = init_service(base_options(), Some(Backend::vertex_ai("us-central1")));
        let defaults = RequestOptions {
            timeout: Some(Duration::from_secs(5)),
            base_url: Some("https://example.com".into()),
        };
        let model = GenerativeModel::new(service, "gemini-pro", Some(defaults)).unwrap();

        assert_eq!(model.model(), "publishers/google/models/gemini-pro");

        let prepared = model
            .prepare_generate_content_request(json!({"contents": []}), None)
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://example.com/v1beta/projects/project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
        );
        assert_eq!(prepared.timeout, Duration::from_secs(5));
    }
}