rust_genai/
operations.rs

1//! Operations API surface.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use reqwest::header::{HeaderName, HeaderValue};
7use rust_genai_types::operations::{
8    GetOperationConfig, ListOperationsConfig, ListOperationsResponse, Operation,
9};
10
11use crate::client::{Backend, ClientInner};
12use crate::error::{Error, Result};
13
14#[derive(Clone)]
15pub struct Operations {
16    pub(crate) inner: Arc<ClientInner>,
17}
18
19impl Operations {
20    pub(crate) fn new(inner: Arc<ClientInner>) -> Self {
21        Self { inner }
22    }
23
24    /// 获取操作状态。
25    pub async fn get(&self, name: impl AsRef<str>) -> Result<Operation> {
26        self.get_with_config(name, GetOperationConfig::default())
27            .await
28    }
29
30    /// 获取操作状态(带配置)。
31    pub async fn get_with_config(
32        &self,
33        name: impl AsRef<str>,
34        mut config: GetOperationConfig,
35    ) -> Result<Operation> {
36        let http_options = config.http_options.take();
37        let name = normalize_operation_name(&self.inner, name.as_ref())?;
38        let url = build_operation_url(&self.inner, &name, http_options.as_ref())?;
39        let mut request = self.inner.http.get(url);
40        request = apply_http_options(request, http_options.as_ref())?;
41
42        let response = self.inner.send(request).await?;
43        if !response.status().is_success() {
44            return Err(Error::ApiError {
45                status: response.status().as_u16(),
46                message: response.text().await.unwrap_or_default(),
47            });
48        }
49        Ok(response.json::<Operation>().await?)
50    }
51
52    /// 列出操作。
53    pub async fn list(&self) -> Result<ListOperationsResponse> {
54        self.list_with_config(ListOperationsConfig::default()).await
55    }
56
57    /// 列出操作(带配置)。
58    pub async fn list_with_config(
59        &self,
60        mut config: ListOperationsConfig,
61    ) -> Result<ListOperationsResponse> {
62        let http_options = config.http_options.take();
63        let url = build_operations_list_url(&self.inner, http_options.as_ref())?;
64        let url = add_list_query_params(url, &config)?;
65        let mut request = self.inner.http.get(url);
66        request = apply_http_options(request, http_options.as_ref())?;
67
68        let response = self.inner.send(request).await?;
69        if !response.status().is_success() {
70            return Err(Error::ApiError {
71                status: response.status().as_u16(),
72                message: response.text().await.unwrap_or_default(),
73            });
74        }
75        Ok(response.json::<ListOperationsResponse>().await?)
76    }
77
78    /// 列出所有操作(自动翻页)。
79    pub async fn all(&self) -> Result<Vec<Operation>> {
80        self.all_with_config(ListOperationsConfig::default()).await
81    }
82
83    /// 列出所有操作(带配置,自动翻页)。
84    pub async fn all_with_config(
85        &self,
86        mut config: ListOperationsConfig,
87    ) -> Result<Vec<Operation>> {
88        let mut ops = Vec::new();
89        let http_options = config.http_options.clone();
90        loop {
91            let mut page_config = config.clone();
92            page_config.http_options = http_options.clone();
93            let response = self.list_with_config(page_config).await?;
94            if let Some(items) = response.operations {
95                ops.extend(items);
96            }
97            match response.next_page_token {
98                Some(token) if !token.is_empty() => {
99                    config.page_token = Some(token);
100                }
101                _ => break,
102            }
103        }
104        Ok(ops)
105    }
106
107    /// 等待操作完成(轮询)。
108    pub async fn wait(&self, mut operation: Operation) -> Result<Operation> {
109        let name = operation.name.clone().ok_or_else(|| Error::InvalidConfig {
110            message: "Operation name is empty".into(),
111        })?;
112        while !operation.done.unwrap_or(false) {
113            tokio::time::sleep(Duration::from_secs(5)).await;
114            operation = self.get(&name).await?;
115        }
116        Ok(operation)
117    }
118}
119
120fn normalize_operation_name(inner: &ClientInner, name: &str) -> Result<String> {
121    match inner.config.backend {
122        Backend::GeminiApi => {
123            if name.starts_with("operations/") || name.starts_with("models/") {
124                Ok(name.to_string())
125            } else {
126                Ok(format!("operations/{name}"))
127            }
128        }
129        Backend::VertexAi => {
130            let vertex =
131                inner
132                    .config
133                    .vertex_config
134                    .as_ref()
135                    .ok_or_else(|| Error::InvalidConfig {
136                        message: "Vertex config missing".into(),
137                    })?;
138            if name.starts_with("projects/") {
139                Ok(name.to_string())
140            } else if name.starts_with("locations/") {
141                Ok(format!("projects/{}/{}", vertex.project, name))
142            } else if name.starts_with("operations/") {
143                Ok(format!(
144                    "projects/{}/locations/{}/{}",
145                    vertex.project, vertex.location, name
146                ))
147            } else {
148                Ok(format!(
149                    "projects/{}/locations/{}/operations/{}",
150                    vertex.project, vertex.location, name
151                ))
152            }
153        }
154    }
155}
156
157fn build_operation_url(
158    inner: &ClientInner,
159    name: &str,
160    http_options: Option<&rust_genai_types::http::HttpOptions>,
161) -> Result<String> {
162    let base = http_options
163        .and_then(|opts| opts.base_url.as_deref())
164        .unwrap_or(&inner.api_client.base_url);
165    let version = http_options
166        .and_then(|opts| opts.api_version.as_deref())
167        .unwrap_or(&inner.api_client.api_version);
168    Ok(format!("{base}{version}/{name}"))
169}
170
171fn build_operations_list_url(
172    inner: &ClientInner,
173    http_options: Option<&rust_genai_types::http::HttpOptions>,
174) -> Result<String> {
175    let base = http_options
176        .and_then(|opts| opts.base_url.as_deref())
177        .unwrap_or(&inner.api_client.base_url);
178    let version = http_options
179        .and_then(|opts| opts.api_version.as_deref())
180        .unwrap_or(&inner.api_client.api_version);
181    let url = match inner.config.backend {
182        Backend::GeminiApi => format!("{base}{version}/operations"),
183        Backend::VertexAi => {
184            let vertex =
185                inner
186                    .config
187                    .vertex_config
188                    .as_ref()
189                    .ok_or_else(|| Error::InvalidConfig {
190                        message: "Vertex config missing".into(),
191                    })?;
192            format!(
193                "{base}{version}/projects/{}/locations/{}/operations",
194                vertex.project, vertex.location
195            )
196        }
197    };
198    Ok(url)
199}
200
201fn add_list_query_params(url: String, config: &ListOperationsConfig) -> Result<String> {
202    let mut url = reqwest::Url::parse(&url).map_err(|err| Error::InvalidConfig {
203        message: err.to_string(),
204    })?;
205    {
206        let mut pairs = url.query_pairs_mut();
207        if let Some(page_size) = config.page_size {
208            pairs.append_pair("pageSize", &page_size.to_string());
209        }
210        if let Some(page_token) = &config.page_token {
211            pairs.append_pair("pageToken", page_token);
212        }
213        if let Some(filter) = &config.filter {
214            pairs.append_pair("filter", filter);
215        }
216    }
217    Ok(url.to_string())
218}
219
220fn apply_http_options(
221    mut request: reqwest::RequestBuilder,
222    http_options: Option<&rust_genai_types::http::HttpOptions>,
223) -> Result<reqwest::RequestBuilder> {
224    if let Some(options) = http_options {
225        if let Some(timeout) = options.timeout {
226            request = request.timeout(Duration::from_millis(timeout));
227        }
228        if let Some(headers) = &options.headers {
229            for (key, value) in headers {
230                let name =
231                    HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
232                        message: format!("Invalid header name: {key}"),
233                    })?;
234                let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
235                    message: format!("Invalid header value for {key}"),
236                })?;
237                request = request.header(name, value);
238            }
239        }
240    }
241    Ok(request)
242}