model_gateway_rs/sdk/
doubao_vision.rs

1use async_trait::async_trait;
2use toolcraft_request::{ByteStream, Request};
3
4use crate::{
5    error::Result,
6    model::{
7        doubao_vision::{DoubaoVisionRequest, DoubaoVisionResponse},
8        vision::{VisionInput, VisionOutput},
9    },
10    sdk::ModelSDK,
11};
12
13/// DoubaoVision client using wrapped Request.
14pub struct DoubaoVisionSdk {
15    request: Request,
16    model: String,
17}
18
19impl DoubaoVisionSdk {
20    pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
21        let mut request = Request::new()?;
22        request.set_base_url(base_url)?;
23        request.set_default_headers(vec![
24            ("Content-Type", "application/json".to_string()),
25            ("Authorization", format!("Bearer {api_key}")),
26        ])?;
27        Ok(Self {
28            request,
29            model: model.to_string(),
30        })
31    }
32
33    /// Create with default model "doubao-1-5-thinking-vision-pro-250428"
34    pub fn new_with_default_model(api_key: &str, base_url: &str) -> Result<Self> {
35        Self::new(api_key, base_url, "doubao-1-5-thinking-vision-pro-250428")
36    }
37}
38
39#[async_trait]
40impl ModelSDK for DoubaoVisionSdk {
41    type Input = VisionInput;
42    type Output = VisionOutput;
43
44    /// Send a vision request and get full response.
45    async fn chat_once(&self, input: Self::Input) -> Result<Self::Output> {
46        let body = DoubaoVisionRequest {
47            model: self.model.clone(),
48            messages: input.messages,
49            thinking: None,
50            stream: None,
51            temperature: None,
52            top_p: None,
53            max_tokens: None,
54            stop: None,
55        };
56        let payload = serde_json::to_value(body)?;
57        let response = self
58            .request
59            .post("chat/completions", &payload, None)
60            .await?;
61        let json: DoubaoVisionResponse = response.json().await?;
62        Ok(json.into())
63    }
64
65    /// Send a vision request and get response stream (SSE).
66    async fn chat_stream(&self, input: Self::Input) -> Result<ByteStream> {
67        let body = DoubaoVisionRequest {
68            model: self.model.clone(),
69            messages: input.messages,
70            thinking: None,
71            stream: Some(true),
72            temperature: None,
73            top_p: None,
74            max_tokens: None,
75            stop: None,
76        };
77        let payload = serde_json::to_value(body)?;
78        let r = self
79            .request
80            .post_stream("chat/completions", &payload, None)
81            .await?;
82        Ok(r)
83    }
84}