use std::sync::Arc;
use async_trait::async_trait;
use futures_util::StreamExt;
use serde_json::Value;
use crate::llm_provider::vertexai::client::VertexAIClient;
use crate::model_api_provider::provider::{
ModelApiProvider, ProviderBody, ProviderRequest, ProviderResponse, filter_request_headers,
filter_response_headers, parse_stream_flag,
};
use crate::serve_config::{ConfigError, ProviderConfig};
/// [`ModelApiProvider`] implementation that forwards request bodies to an
/// upstream model through a [`VertexAIClient`].
#[derive(Clone)]
pub struct GenerateContentClient {
/// Identifier reported by [`ModelApiProvider::model_id`].
model_id: String,
/// Model name handed to the client on every upstream call.
upstream_model: String,
/// Client used to send the upstream request.
client: VertexAIClient,
}
impl GenerateContentClient {
pub fn new(model_id: String, upstream_model: String, client: VertexAIClient) -> Self {
Self {
model_id,
upstream_model,
client,
}
}
}
#[async_trait]
impl ModelApiProvider for GenerateContentClient {
    /// Returns the identifier this provider was constructed with.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Forwards `req` to the upstream model and relays the reply.
    ///
    /// The flag returned by `parse_stream_flag` for the request body selects
    /// the response shape:
    /// * `false` - the whole upstream body is read and returned as
    ///   `ProviderBody::Full`.
    /// * `true` - the upstream byte stream is returned as
    ///   `ProviderBody::Stream`, with transport errors wrapped in
    ///   `std::io::Error`, and the `Content-Length` header is dropped from
    ///   the relayed headers.
    ///
    /// # Errors
    /// Propagates any error from the upstream call or from reading the full
    /// response body.
    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse, anyhow::Error> {
        let wants_stream = parse_stream_flag(&req.body);
        let forwarded_headers = filter_request_headers(req.headers);
        let payload = req.body.to_vec();

        let upstream = self
            .client
            .post_json_with_headers(
                &self.upstream_model,
                payload,
                wants_stream,
                forwarded_headers,
            )
            .await?;

        let status = upstream.status();
        let mut headers = filter_response_headers(upstream.headers());

        if !wants_stream {
            let body = ProviderBody::Full(upstream.bytes().await?);
            return Ok(ProviderResponse {
                status,
                headers,
                body,
            });
        }

        // NOTE(review): presumably dropped because the relayed stream's total
        // length is not known up front - confirm.
        headers.remove(axum::http::header::CONTENT_LENGTH);
        let chunks = upstream.bytes_stream().map(|chunk| {
            chunk.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))
        });
        Ok(ProviderResponse {
            status,
            headers,
            body: ProviderBody::Stream(Box::pin(chunks)),
        })
    }

    /// Delegates to [`extract_request_id_json`].
    fn extract_request_id(&self, payload_json: &Value) -> Option<String> {
        extract_request_id_json(payload_json)
    }

    /// Delegates to [`extract_usage_json`].
    fn extract_usage(&self, payload_json: &Value) -> Option<Value> {
        extract_usage_json(payload_json)
    }

    /// Delegates to [`inject_usage_json`].
    fn inject_usage(&self, payload_json: &mut Value, usage: Value) -> bool {
        inject_usage_json(payload_json, usage)
    }
}
/// Pulls a request/response identifier out of a decoded payload.
///
/// The first candidate holding a JSON string wins, in this order:
/// 1. top-level `id`
/// 2. `response.id`
/// 3. top-level `responseId`
/// 4. `response.responseId`
///
/// The `responseId` candidates cover the field name used by the Vertex AI /
/// Gemini `GenerateContentResponse` schema. They are consulted only after
/// both `id` locations, so any payload that previously yielded an id still
/// yields the same one.
///
/// Returns `None` when no candidate is present or none is a string.
fn extract_request_id_json(payload_json: &Value) -> Option<String> {
    let nested = payload_json.get("response");
    let candidates = [
        payload_json.get("id"),
        nested.and_then(|response| response.get("id")),
        payload_json.get("responseId"),
        nested.and_then(|response| response.get("responseId")),
    ];
    candidates
        .iter()
        .copied()
        .flatten()
        // Non-string values (numbers, null, ...) are skipped, not stringified.
        .find_map(Value::as_str)
        .map(str::to_string)
}
/// Returns a copy of the payload's `usageMetadata`.
///
/// The top-level field is preferred; if it is absent or JSON `null`, the
/// field nested under `response` is tried instead. `None` is returned when
/// neither location holds a non-null value.
fn extract_usage_json(payload_json: &Value) -> Option<Value> {
    if let Some(top_level) = non_null_usage(payload_json.get("usageMetadata")) {
        return Some(top_level);
    }
    let nested = payload_json
        .get("response")
        .and_then(|response| response.get("usageMetadata"));
    non_null_usage(nested)
}
/// Writes `usage` into the payload as `usageMetadata`.
///
/// * If the payload is not a JSON object, nothing is written.
/// * If the top-level object already has a `usageMetadata` key (any value,
///   including `null`), that entry is replaced.
/// * Otherwise, if `response` is a JSON object, `usageMetadata` is set on it
///   whether or not it existed before.
///
/// Returns `true` when a value was written and `false` otherwise.
fn inject_usage_json(payload_json: &mut Value, usage: Value) -> bool {
    let Some(root) = payload_json.as_object_mut() else {
        return false;
    };

    // Pick the object that should receive the usage block, if any.
    let target = if root.contains_key("usageMetadata") {
        Some(root)
    } else {
        root.get_mut("response").and_then(Value::as_object_mut)
    };

    match target {
        Some(object) => {
            object.insert("usageMetadata".to_string(), usage);
            true
        }
        None => false,
    }
}
/// Clones `value` unless it is missing or JSON `null`, in which case `None`
/// is returned.
fn non_null_usage(value: Option<&Value>) -> Option<Value> {
    match value {
        Some(usage) if !usage.is_null() => Some(usage.clone()),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::{extract_request_id_json, extract_usage_json, inject_usage_json};
    use serde_json::json;

    /// A top-level `id` wins over one nested under `response`.
    #[test]
    fn extract_request_id_json_prefers_top_level_id() {
        let both_ids = json!({"id": "gen_top", "response": {"id": "gen_nested"}});
        assert_eq!(
            extract_request_id_json(&both_ids),
            Some("gen_top".to_string())
        );
    }

    /// With no top-level `id`, the one under `response` is used.
    #[test]
    fn extract_request_id_json_supports_nested_response_id() {
        let nested_only = json!({"response": {"id": "gen_nested"}});
        assert_eq!(
            extract_request_id_json(&nested_only),
            Some("gen_nested".to_string())
        );
    }

    /// `usageMetadata` nested under `response` is found.
    #[test]
    fn extract_usage_json_reads_nested_usage_metadata() {
        let wrapped = json!({"response": {"usageMetadata": {"totalTokenCount": 42}}});
        let expected = json!({"totalTokenCount": 42});
        assert_eq!(extract_usage_json(&wrapped), Some(expected));
    }

    /// An explicit `null` counts as missing usage.
    #[test]
    fn extract_usage_json_ignores_null_usage_metadata() {
        let null_usage = json!({"usageMetadata": null});
        assert_eq!(extract_usage_json(&null_usage), None);
    }

    /// Injection overwrites usage nested under `response` and reports success.
    #[test]
    fn inject_usage_json_updates_nested_response_usage_metadata() {
        let mut wrapped = json!({"response": {"usageMetadata": {"totalTokenCount": 1}}});
        let replacement = json!({"totalTokenCount": 99});
        assert!(inject_usage_json(&mut wrapped, replacement.clone()));
        assert_eq!(wrapped["response"]["usageMetadata"], replacement);
    }
}
/// Builds a [`GenerateContentClient`] from a provider configuration entry.
///
/// Parameters read from `provider.params`:
/// * `project_id` - required.
/// * `location` - optional, falls back to `"global"`.
/// * `credentials_file` - required.
/// * `model` - required; used as the upstream model name.
///
/// # Errors
/// * [`ConfigError::UnknownProviderType`] when `provider_type` is not
///   `"generate_content"`.
/// * [`ConfigError::InvalidProvider`] when a required parameter is missing or
///   the `VertexAIClient` cannot be constructed.
pub fn build_client(provider: &ProviderConfig) -> Result<Arc<dyn ModelApiProvider>, ConfigError> {
    if provider.provider_type != "generate_content" {
        let unknown = provider.provider_type.clone();
        return Err(ConfigError::UnknownProviderType(unknown));
    }

    let param = |key: &str| provider.params.get(key).cloned();

    // Parameters are validated in a fixed order so the first missing one is
    // the one reported.
    let Some(project_id) = param("project_id") else {
        return Err(ConfigError::InvalidProvider(
            "project_id is required".to_string(),
        ));
    };
    let location = param("location").unwrap_or_else(|| "global".to_string());
    let Some(credentials_file) = param("credentials_file") else {
        return Err(ConfigError::InvalidProvider(
            "credentials_file is required".to_string(),
        ));
    };
    let Some(upstream_model) = param("model") else {
        return Err(ConfigError::InvalidProvider(
            "model is required".to_string(),
        ));
    };

    let client = match VertexAIClient::new(project_id, location, credentials_file) {
        Ok(client) => client,
        Err(err) => return Err(ConfigError::InvalidProvider(err.to_string())),
    };

    let provider_impl =
        GenerateContentClient::new(provider.model_id.clone(), upstream_model, client);
    Ok(Arc::new(provider_impl))
}