model_gateway_rs/sdk/
doubao_vision.rs1use async_trait::async_trait;
2use toolcraft_request::{ByteStream, Request};
3
4use crate::{
5 error::Result,
6 model::{
7 doubao_vision::{DoubaoVisionRequest, DoubaoVisionResponse},
8 vision::{VisionInput, VisionOutput},
9 },
10 sdk::ModelSDK,
11};
12
13pub struct DoubaoVisionSdk {
15 request: Request,
16 model: String,
17}
18
19impl DoubaoVisionSdk {
20 pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
21 let mut request = Request::new()?;
22 request.set_base_url(base_url)?;
23 request.set_default_headers(vec![
24 ("Content-Type", "application/json".to_string()),
25 ("Authorization", format!("Bearer {api_key}")),
26 ])?;
27 Ok(Self {
28 request,
29 model: model.to_string(),
30 })
31 }
32
33 pub fn new_with_default_model(api_key: &str, base_url: &str) -> Result<Self> {
35 Self::new(api_key, base_url, "doubao-1-5-thinking-vision-pro-250428")
36 }
37}
38
39#[async_trait]
40impl ModelSDK for DoubaoVisionSdk {
41 type Input = VisionInput;
42 type Output = VisionOutput;
43
44 async fn chat_once(&self, input: Self::Input) -> Result<Self::Output> {
46 let body = DoubaoVisionRequest {
47 model: self.model.clone(),
48 messages: input.messages,
49 thinking: None,
50 stream: None,
51 temperature: None,
52 top_p: None,
53 max_tokens: None,
54 stop: None,
55 };
56 let payload = serde_json::to_value(body)?;
57 let response = self
58 .request
59 .post("chat/completions", &payload, None)
60 .await?;
61 let json: DoubaoVisionResponse = response.json().await?;
62 Ok(json.into())
63 }
64
65 async fn chat_stream(&self, input: Self::Input) -> Result<ByteStream> {
67 let body = DoubaoVisionRequest {
68 model: self.model.clone(),
69 messages: input.messages,
70 thinking: None,
71 stream: Some(true),
72 temperature: None,
73 top_p: None,
74 max_tokens: None,
75 stop: None,
76 };
77 let payload = serde_json::to_value(body)?;
78 let r = self
79 .request
80 .post_stream("chat/completions", &payload, None)
81 .await?;
82 Ok(r)
83 }
84}