openai_ng/proto/
chat.rs

1use crate::client::Client;
2use crate::error::*;
3use crate::proto::tool::*;
4
5use base64::Engine;
6use futures::StreamExt;
7use http::{
8    header::{self, HeaderValue},
9    Method,
10};
11use reqwest::Body;
12use serde::de::{Deserialize, IntoDeserializer};
13use serde_with::skip_serializing_none;
14use smart_default::SmartDefault;
15use tokio::sync::mpsc::Receiver;
16use tracing::*;
17
18use std::time::Duration;
19
20#[skip_serializing_none]
21#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
22pub struct ChatCompletionRequest {
23    pub model: String,
24    pub messages: Vec<Message>,
25    #[serde(default)]
26    pub tools: Vec<ToolCall>,
27    pub max_tokens: Option<u64>,
28    pub temperature: Option<f64>,
29    pub top_p: Option<f64>,
30    pub n: Option<u64>,
31    pub stream: Option<bool>,
32    pub stop: Option<Stop>,
33    pub frequency_penalty: Option<f64>,
34    pub response_format: Option<ResponseFormat>,
35}
36
37pub enum ChatCompletionResult {
38    Response(ChatCompletionResponse),
39    Delta(Receiver<Result<ChatCompletionStreamData>>),
40}
41
42impl ChatCompletionRequest {
43    pub async fn call_once(
44        &self,
45        client: &Client,
46        timeout: Option<Duration>,
47    ) -> Result<ChatCompletionResponse> {
48        let uri = "chat/completions";
49
50        let rep = client
51            .call_impl(
52                Method::POST,
53                uri,
54                vec![(
55                    header::CONTENT_TYPE,
56                    HeaderValue::from_str("application/json")?,
57                )],
58                Some(Body::from(serde_json::to_vec(&self)?)),
59                None,
60                timeout,
61            )
62            .await?;
63
64        let status = rep.status();
65
66        let rep = serde_json::from_slice::<serde_json::Value>(rep.bytes().await?.as_ref())?;
67
68        for l in serde_json::to_string_pretty(&rep)?.split("\n") {
69            if status.is_success() {
70                tracing::trace!("REP: {}", l);
71            } else {
72                tracing::error!("REP: {}", l);
73            }
74        }
75
76        if status.is_success() {
77            let rep: ChatCompletionResponse = serde_json::from_value(rep)?;
78            Ok(rep)
79        } else {
80            error!("chat completion failed");
81            Err(Error::ApiError(status.as_u16()))
82        }
83    }
84
85    pub async fn call_stream(
86        &self,
87        client: &Client,
88        timeout: Option<Duration>,
89    ) -> Result<Receiver<Result<ChatCompletionStreamData>>> {
90        let uri = "chat/completions";
91
92        let rep = client
93            .call_impl(
94                Method::POST,
95                uri,
96                vec![(
97                    header::CONTENT_TYPE,
98                    HeaderValue::from_str("application/json")?,
99                )],
100                Some(Body::from(serde_json::to_vec(&self)?)),
101                None,
102                timeout,
103            )
104            .await?;
105
106        let (tx, rx) = tokio::sync::mpsc::channel(1);
107
108        tokio::spawn(async move {
109            let mut stack = vec![];
110            let mut stream = rep.bytes_stream();
111
112            let s_tag = "data: ".as_bytes();
113            let s_tag_len = s_tag.len();
114            let e_tag = "\n\n".as_bytes();
115            let e_tag_len = e_tag.len();
116
117            while let Some(r) = stream.next().await {
118                let chunk = match r {
119                    Ok(r) => r,
120                    Err(e) => {
121                        error!("stream return with error: {:?}", e);
122                        break;
123                    }
124                };
125
126                trace!("recv chunk {} bytes", chunk.len());
127
128                for b in chunk.as_ref() {
129                    stack.push(*b);
130                    if stack.len() >= e_tag_len + s_tag_len {
131                        let slice = &stack[stack.len() - e_tag_len..];
132
133                        if slice == e_tag {
134                            let mut data = vec![];
135                            std::mem::swap(&mut data, &mut stack);
136
137                            let data =
138                                String::from_utf8_lossy(&data[s_tag_len..data.len() - e_tag_len]);
139
140                            if data.find("[DONE]").is_some() {
141                                trace!("met [DONE], data={}", data);
142                                continue;
143                            }
144
145                            match serde_json::from_str::<ChatCompletionStreamData>(&data) {
146                                Err(e) => {
147                                    error!("failed to parse data: error={:?}, data={}", e, data);
148                                    tx.send(Err(e.into())).await.map_err(|_| {
149                                        error!("failed to send error message to chat receiver");
150                                        Error::SendMessage
151                                    })?;
152                                }
153                                Ok(data) => {
154                                    trace!("found data event from stream");
155                                    for l in serde_json::to_string_pretty(&data)?.lines() {
156                                        trace!("DATA: {}", l);
157                                    }
158                                    tx.send(Ok(data)).await.map_err(|_| {
159                                        error!("failed to send data message to chat receiver");
160                                        Error::SendMessage
161                                    })?;
162                                }
163                            }
164                        }
165                    }
166                }
167            }
168            trace!("stream thread quit, with stack.len()={}", stack.len());
169            Result::Ok(())
170        });
171
172        Ok(rx)
173    }
174
175    pub async fn call(
176        &self,
177        client: &crate::client::Client,
178        timeout: Option<std::time::Duration>,
179    ) -> Result<ChatCompletionResult> {
180        match self.stream {
181            Some(true) => Ok(ChatCompletionResult::Delta(
182                self.call_stream(client, timeout).await?,
183            )),
184            _ => Ok(ChatCompletionResult::Response(
185                self.call_once(client, timeout).await?,
186            )),
187        }
188    }
189}
190
/// Value of the request's `response_format` field.
///
/// Serialized as an object with a single `type` key.
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ResponseFormat {
    /// Requested response type; stored under the JSON key `type`.
    #[serde(rename = "type")]
    typ: ResponseType,
}
197
/// Response types that can be requested through [`ResponseFormat`].
///
/// Variant names are lower-case on purpose: they are used verbatim as the
/// serialized value.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[allow(non_camel_case_types)]
pub enum ResponseType {
    /// Serialized as the string `json_object`.
    json_object,
}
203
204#[derive(Debug, Clone, SmartDefault)]
205pub struct ChatCompletionRequestBuilder {
206    model: Option<String>,
207    messages: Vec<Message>,
208    tools: Vec<ToolCall>,
209    max_tokens: Option<u64>,
210    temperature: Option<f64>,
211    top_p: Option<f64>,
212    n: Option<u64>,
213    stream: Option<bool>,
214    stop: Option<Stop>,
215    frequency_penalty: Option<f64>,
216    response_format: Option<ResponseFormat>,
217}
218
219impl ChatCompletionRequestBuilder {
220    pub fn with_reponse_format(mut self, format: ResponseType) -> Self {
221        self.response_format = Some(ResponseFormat { typ: format });
222        self
223    }
224
225    pub fn with_model(mut self, model: impl Into<String>) -> Self {
226        self.model = Some(model.into());
227        self
228    }
229
230    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
231        self.messages.extend(messages);
232        self
233    }
234
235    pub fn add_message(mut self, msg: Message) -> Self {
236        self.messages.push(msg);
237        self
238    }
239
240    pub fn with_tool(mut self, tool: impl Into<ToolCall>) -> Self {
241        self.tools.push(tool.into());
242        self
243    }
244
245    pub fn with_tools<T>(mut self, tools: impl IntoIterator<Item = T>) -> Self
246    where
247        T: Into<ToolCall>,
248    {
249        self.tools.extend(tools.into_iter().map(|t| t.into()));
250        self
251    }
252
253    pub fn add_tool(self, tool: impl Into<ToolCall>) -> Self {
254        self.with_tool(tool)
255    }
256
257    pub fn with_max_tokens(mut self, max_tokens: u64) -> Self {
258        self.max_tokens = Some(max_tokens);
259        self
260    }
261
262    pub fn with_temperature(mut self, temperature: f64) -> Self {
263        self.temperature = Some(temperature);
264        self
265    }
266
267    pub fn with_n(mut self, n: u64) -> Self {
268        self.n = Some(n);
269        self
270    }
271
272    pub fn with_stream(mut self, stream: bool) -> Self {
273        self.stream = Some(stream);
274        self
275    }
276
277    pub fn with_stop(mut self, rhs: Stop) -> Self {
278        self.stop = Some(rhs);
279        self
280    }
281
282    pub fn add_stop(mut self, rhs: Stop) -> Self {
283        let mut lhs = None;
284        std::mem::swap(&mut self.stop, &mut lhs);
285        self.stop = match lhs {
286            None => Some(rhs),
287            Some(lhs) => Some(lhs.append(rhs)),
288        };
289        self
290    }
291
292    pub fn with_frequency_penalty(mut self, frequency_penalty: f64) -> Self {
293        self.frequency_penalty = Some(frequency_penalty);
294        self
295    }
296
297    pub fn build(self) -> Result<ChatCompletionRequest> {
298        let Self {
299            model,
300            messages,
301            tools,
302            max_tokens,
303            temperature,
304            top_p,
305            n,
306            stream,
307            stop,
308            frequency_penalty,
309            response_format,
310        } = self;
311
312        let model = model.ok_or(Error::ChatCompletionRequestBuild)?;
313
314        if messages.is_empty() {
315            return Err(Error::ChatCompletionRequestBuild);
316        }
317
318        let r = ChatCompletionRequest {
319            model,
320            messages,
321            tools,
322            max_tokens,
323            temperature,
324            top_p,
325            n,
326            stream,
327            stop,
328            frequency_penalty,
329            response_format,
330        };
331
332        for l in serde_json::to_string_pretty(&r)?.lines() {
333            trace!("REQ: {}", l);
334        }
335
336        Ok(r)
337    }
338}
339
340impl ChatCompletionRequest {
341    pub fn builder() -> ChatCompletionRequestBuilder {
342        ChatCompletionRequestBuilder::default()
343    }
344}
345
346#[skip_serializing_none]
347#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
348pub struct ChatCompletionResponse {
349    pub id: String,
350    #[default("chat.completion".to_string())]
351    pub object: String,
352    pub created: u64,
353    pub model: String,
354    #[serde(default)]
355    pub choices: Vec<Choice>,
356    pub usage: Option<ChatComplitionUsage>,
357}
358
359impl ChatCompletionResponse {
360    pub fn merge_delta(&mut self, delta: ChatCompletionStreamData) {
361        let ChatCompletionStreamData {
362            id,
363            object,
364            created,
365            model,
366            choices,
367            usage,
368        } = delta;
369
370        if let Some(usage) = usage {
371            self.usage = Some(usage);
372        }
373
374        if let Some(id) = id {
375            self.id = id;
376        }
377
378        if let Some(object) = object {
379            if self.object.is_empty() {
380                self.object = object;
381            }
382        }
383
384        if let Some(created) = created {
385            self.created = created;
386        }
387
388        if let Some(model) = model {
389            self.model = model;
390        }
391
392        'outer: for delta in choices {
393            let StreamChoice {
394                index,
395                delta,
396                finish_reason,
397                usage,
398            } = delta;
399
400            if let Some(usage) = usage {
401                self.usage = Some(usage);
402            }
403
404            let Message {
405                role,
406                content,
407                tool_calls,
408                tool_call_id,
409            } = delta;
410
411            for choice in &mut self.choices {
412                if choice.index == index {
413                    if let Some(role) = role {
414                        choice.message.role = Some(role);
415                    }
416
417                    if let Some(delta_content) = content {
418                        let mut choice_content = None;
419                        std::mem::swap(&mut choice.message.content, &mut choice_content);
420                        match choice_content.as_mut() {
421                            Some(c) => c.merge(delta_content),
422                            None => choice_content = Some(delta_content),
423                        };
424                        std::mem::swap(&mut choice.message.content, &mut choice_content);
425                    }
426
427                    if let Some(tool_call_id) = tool_call_id {
428                        choice.message.tool_call_id = Some(tool_call_id);
429                    }
430
431                    if choice.message.tool_calls.is_empty() {
432                        choice.message.tool_calls = tool_calls;
433                    } else {
434                        choice
435                            .message
436                            .tool_calls
437                            .iter_mut()
438                            .zip(tool_calls)
439                            .for_each(|(lhs, rhs)| {
440                                if let Some(name) = rhs.function.name.as_ref() {
441                                    if !name.is_empty() {
442                                        lhs.function.name = Some(name.clone());
443                                    }
444                                }
445
446                                match (&mut lhs.function.arguments, &rhs.function.arguments) {
447                                    (Some(lhs), Some(rhs)) => {
448                                        *lhs = format!("{}{}", lhs, rhs);
449                                    }
450                                    (None, Some(rhs)) => {
451                                        lhs.function.arguments = Some(rhs.clone());
452                                    }
453                                    _ => {}
454                                }
455                            });
456                    }
457
458                    if let Some(finish_reason) = finish_reason {
459                        choice.finish_reason = Some(finish_reason);
460                    }
461
462                    continue 'outer;
463                }
464            }
465
466            self.choices.push(Choice {
467                index,
468                message: Message {
469                    role,
470                    content,
471                    tool_call_id,
472                    tool_calls,
473                },
474                finish_reason,
475            });
476        }
477    }
478}
479
/// One choice of a [`ChatCompletionResponse`].
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Choice {
    /// Position of the choice; `merge_delta` uses it to match streamed chunks.
    pub index: usize,
    /// Message carried by this choice.
    pub message: Message,
    /// Finish reason, when reported; omitted from the JSON when `None`.
    pub finish_reason: Option<String>,
}
487
488#[skip_serializing_none]
489#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
490pub enum Stop {
491    Text(String),
492    Texts(Vec<String>),
493}
494
495impl Stop {
496    pub fn append(self, rhs: Stop) -> Self {
497        match (self, rhs) {
498            (Stop::Text(lhs), Stop::Text(rhs)) => Stop::Texts(vec![lhs, rhs]),
499            (Stop::Text(lhs), Stop::Texts(mut rhs)) => {
500                rhs.push(lhs);
501                Stop::Texts(rhs)
502            }
503            (Stop::Texts(mut lhs), Stop::Text(rhs)) => {
504                lhs.push(rhs);
505                Stop::Texts(lhs)
506            }
507            (Stop::Texts(mut lhs), Stop::Texts(rhs)) => {
508                lhs.extend(rhs);
509                Stop::Texts(lhs)
510            }
511        }
512    }
513}
514
/// Token usage figures attached to a response or to a streamed chunk.
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
pub struct ChatComplitionUsage {
    /// `cached_tokens` figure, when reported; omitted from the JSON when `None`.
    pub cached_tokens: Option<u64>,
    /// `completion_tokens` figure.
    pub completion_tokens: u64,
    /// `prompt_tokens` figure.
    pub prompt_tokens: u64,
    /// `total_tokens` figure.
    pub total_tokens: u64,
}
523
524#[skip_serializing_none]
525#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, SmartDefault)]
526pub struct Message {
527    #[serde(default, deserialize_with = "empty_string_as_none")]
528    pub role: Option<Role>,
529    pub content: Option<Content>,
530    pub tool_call_id: Option<String>,
531    #[serde(default)]
532    pub tool_calls: Vec<ToolCall>,
533}
534
535fn empty_string_as_none<'de, D>(de: D) -> std::result::Result<Option<Role>, D::Error>
536where
537    D: serde::Deserializer<'de>,
538{
539    let opt = Option::<String>::deserialize(de)?;
540    match opt.as_deref() {
541        None | Some("") => Ok(None),
542        Some(s) => Role::deserialize(s.into_deserializer()).map(Some),
543    }
544}
545
546impl Message {
547    pub fn builder() -> MessageBuilder {
548        MessageBuilder::default()
549    }
550}
551
552#[derive(SmartDefault)]
553pub struct MessageBuilder {
554    role: Option<Role>,
555    content: Option<Content>,
556    tool_call_id: Option<String>,
557    tool_calls: Vec<ToolCall>,
558}
559
560impl MessageBuilder {
561    pub fn with_role(mut self, role: Role) -> Self {
562        self.role = Some(role);
563        self
564    }
565
566    pub fn with_content(mut self, content: impl Into<Content>) -> Self {
567        self.content = Some(content.into());
568        self
569    }
570
571    // pub fn add_content(mut self, content: impl Into<Content>) -> Self {
572    //     let mut lhs = None;
573    //     std::mem::swap(&mut self.content, &mut lhs);
574    //     self.content = match lhs {
575    //         None => Some(content.into()),
576    //         Some(lhs) => Some(lhs.merge(content)),
577    //     };
578    //     self
579    // }
580
581    pub fn with_tool_call_id(mut self, tool_call_id: impl Into<String>) -> Self {
582        self.tool_call_id = Some(tool_call_id.into());
583        self
584    }
585
586    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
587        self.tool_calls = tool_calls;
588        self
589    }
590
591    pub fn add_tool_call(mut self, tool_call: ToolCall) -> Self {
592        self.tool_calls.push(tool_call);
593        self
594    }
595
596    pub fn build(self) -> Message {
597        let Self {
598            role,
599            content,
600            tool_call_id,
601            tool_calls,
602        } = self;
603
604        Message {
605            role,
606            content,
607            tool_call_id,
608            tool_calls,
609        }
610    }
611}
612
/// Author role of a [`Message`].
///
/// Variant names are lower-case on purpose: they are used verbatim as the
/// serialized value.
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[allow(non_camel_case_types)]
pub enum Role {
    /// Serialized as `system`.
    system,
    /// Serialized as `user`.
    user,
    /// Serialized as `assistant`.
    assistant,
    /// Serialized as `tool`.
    tool,
}
622
/// Content of a [`Message`].
///
/// Untagged: `Text` is a bare JSON string, `Containers` a JSON array of
/// content parts.
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
pub enum Content {
    /// Plain text content.
    Text(String),
    /// List of content parts (text and/or images).
    Containers(Vec<ContentContainer>),
}
630
631impl From<String> for Content {
632    fn from(s: String) -> Self {
633        Content::Text(s)
634    }
635}
636
637impl From<&str> for Content {
638    fn from(s: &str) -> Self {
639        Content::Text(s.to_string())
640    }
641}
642
643impl From<ImageUrl> for Content {
644    fn from(url: ImageUrl) -> Self {
645        Content::Containers(vec![ContentContainer::Image {
646            typ: "image_url".into(),
647            image_url: url,
648        }])
649    }
650}
651
652impl Content {
653    pub fn from_image_url(url: &str) -> Self {
654        Content::Containers(vec![ContentContainer::Image {
655            typ: "image_url".into(),
656            image_url: ImageUrl::from_url(url),
657        }])
658    }
659
660    pub fn from_text(text: impl Into<String>) -> Self {
661        Content::Text(text.into())
662    }
663
664    pub fn merge(&mut self, rhs: Self) {
665        *self = match self {
666            Content::Text(s0) => match rhs {
667                Content::Text(s1) => {
668                    *s0 += s1.as_str();
669                    return;
670                }
671                Content::Containers(cs) => {
672                    let mut cs_ = vec![s0.clone().into()];
673                    cs_.extend(cs);
674                    Content::Containers(cs_)
675                }
676            },
677            Content::Containers(cs) => {
678                match rhs {
679                    Content::Text(s1) => cs.push(ContentContainer::Text {
680                        typ: "text".into(),
681                        text: s1,
682                    }),
683                    Content::Containers(cs_) => cs.extend(cs_),
684                }
685                return;
686            }
687        };
688    }
689
690    pub fn append(&mut self, item: impl Into<ContentContainer>) {
691        *self = match self {
692            Content::Text(s) => Content::Containers(vec![
693                ContentContainer::Text {
694                    typ: "text".into(),
695                    text: s.clone(),
696                },
697                item.into(),
698            ]),
699            Content::Containers(cs) => {
700                let mut cs_ = vec![];
701                std::mem::swap(cs, &mut cs_);
702                cs_.push(item.into());
703                Content::Containers(cs_)
704            }
705        };
706    }
707}
708
/// One part of a multi-part [`Content`].
///
/// Untagged: the variant is told apart by the fields present in the JSON
/// object (`text` versus `image_url`).
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
pub enum ContentContainer {
    /// Text part.
    Text {
        /// Part type tag, stored under the JSON key `type`; the constructors
        /// in this file set it to `text`.
        #[serde(rename = "type")]
        typ: String,
        /// The text itself.
        text: String,
    },
    /// Image part.
    Image {
        /// Part type tag, stored under the JSON key `type`; the constructors
        /// in this file set it to `image_url`.
        #[serde(rename = "type")]
        typ: String,
        /// Location of the image.
        image_url: ImageUrl,
    },
}
724
/// Image reference used inside an `image_url` content part.
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ImageUrl {
    /// Either a regular URL or a `data:image/...;base64,...` URL, depending on
    /// which constructor produced the value.
    pub url: String,
}
730
731impl From<ImageUrl> for ContentContainer {
732    fn from(url: ImageUrl) -> Self {
733        ContentContainer::Image {
734            typ: "image_url".into(),
735            image_url: url,
736        }
737    }
738}
739
740impl From<String> for ContentContainer {
741    fn from(s: String) -> Self {
742        ContentContainer::Text {
743            typ: "text".into(),
744            text: s,
745        }
746    }
747}
748
749impl ImageUrl {
750    pub async fn from_local_file(path: impl Into<std::path::PathBuf>) -> Result<Self> {
751        let path = path.into();
752        let suffix = path
753            .extension()
754            .ok_or(Error::NoFileExtension)?
755            .to_str()
756            .ok_or(Error::NoFileExtension)?;
757        let binary = tokio::fs::read(&path).await?;
758        Ok(ImageUrl::from_image_binary(binary, suffix))
759    }
760
761    pub fn from_url(url: impl Into<String>) -> Self {
762        ImageUrl { url: url.into() }
763    }
764
765    pub fn from_image_binary(image: impl AsRef<[u8]>, suffix: impl AsRef<str>) -> Self {
766        ImageUrl {
767            url: format!(
768                "data:image/{};base64,{}",
769                suffix.as_ref(),
770                base64::prelude::BASE64_STANDARD.encode(image)
771            ),
772        }
773    }
774}
775
776#[skip_serializing_none]
777#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
778pub struct ChatCompletionStreamData {
779    pub id: Option<String>,
780    pub object: Option<String>,
781    pub created: Option<u64>,
782    pub model: Option<String>,
783    pub choices: Vec<StreamChoice>,
784    pub usage: Option<ChatComplitionUsage>,
785}
786
/// One choice inside a streamed chunk ([`ChatCompletionStreamData`]).
#[skip_serializing_none]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StreamChoice {
    /// Position of the choice; `merge_delta` matches it against `Choice::index`.
    pub index: usize,
    /// Partial message carried by this chunk.
    pub delta: Message,
    /// Finish reason, when present; omitted from the JSON when `None`.
    pub finish_reason: Option<String>,
    /// Token usage attached to the choice, when present.
    pub usage: Option<ChatComplitionUsage>,
}
795
/// Usage walkthrough: build a client and a request with one tool, send it, and
/// collapse a streamed reply (if any) into a single response.
///
/// Compiled only for tests and never called; it exists as sample code.
#[allow(dead_code)]
#[cfg(test)]
async fn example_code() -> Result<()> {
    // Client pointing at the API endpoint.
    let client = Client::builder()
        .with_base_url("https://api.stepfun.com")?
        .with_key("you api key")?
        .with_version("v1")?
        .build()?;

    // Conversation: one system prompt and one user question.
    let system_msg = Message::builder()
        .with_role(Role::system)
        .with_content("you are a good llm model")
        .build();
    let user_msg = Message::builder()
        .with_role(Role::user)
        .with_content("calculate 1921.23 + 42.00")
        .build();

    // Tool description: `add_number(a, b)` with two required numbers.
    let operand_a = ParameterProperty::builder()
        .with_description("number 1 in 2 numbers")
        .with_type(ParameterType::number)
        .build()?;
    let operand_b = ParameterProperty::builder()
        .with_description("number 2 in 2 numbers")
        .with_type(ParameterType::number)
        .build()?;
    let parameters = Parameters::builder()
        .add_property("a", operand_a, true)
        .add_property("b", operand_b, true)
        .build()?;
    let add_number = Function::builder()
        .with_name("add_number")
        .with_description("add two numbers")
        .with_parameters(parameters)
        .build()?;

    // Request; with `with_stream(true)` the reply arrives as a stream instead.
    let req = ChatCompletionRequest::builder()
        .with_model("step-1-8k")
        .with_messages([system_msg, user_msg])
        .with_tools([add_number])
        .with_stream(false)
        .build()?;

    // The shape of the result depends on the stream flag.
    let rep = match req.call(&client, None).await? {
        // One-shot: the response is already complete.
        ChatCompletionResult::Response(rep) => rep,
        // Streaming: fold every chunk into one response.
        ChatCompletionResult::Delta(mut rx) => {
            let mut merged = ChatCompletionResponse::default();
            while let Some(item) = rx.recv().await {
                match item {
                    Ok(chunk) => merged.merge_delta(chunk),
                    Err(e) => {
                        error!("failed to recv rep: {:?}", e);
                        break;
                    }
                }
            }
            merged
        }
    };

    // Log the final response.
    for l in serde_json::to_string_pretty(&rep)?.lines() {
        info!("FINAL REP: {}", l);
    }

    Ok(())
}
878
/// Live round trip against the configured endpoint.
///
/// Configuration comes from `.env.stepfun` and the environment:
/// `OPENAI_API_MODEL_NAME` selects the model, and the mere presence of
/// `USE_STREAM` switches the request to the streaming path.
#[cfg(test)]
#[tokio::test]
async fn test_chat_simple_ok() -> Result<()> {
    let client = Client::from_env_file(".env.stepfun")?;

    let model_name = std::env::var("OPENAI_API_MODEL_NAME")?;
    let use_stream = std::env::var("USE_STREAM").is_ok();

    let _ = tracing_subscriber::fmt::try_init();

    let req = ChatCompletionRequest::builder()
        // Use the configured model instead of a hard-coded one.
        .with_model(model_name)
        .with_messages([
            Message::builder()
                .with_role(Role::system)
                .with_content("you are a good llm model")
                .build(),
            Message::builder()
                .with_role(Role::user)
                .with_content("calculate 1921.23 + 42.00")
                .build(),
        ])
        .with_tools([Function::builder()
            .with_name("add_number")
            .with_description("add two numbers")
            .with_parameters(
                Parameters::builder()
                    .add_property(
                        "a",
                        ParameterProperty::builder()
                            .with_description("number 1 in 2 numbers")
                            .with_type(ParameterType::number)
                            .build()?,
                        true,
                    )
                    .add_property(
                        "b",
                        ParameterProperty::builder()
                            .with_description("number 2 in 2 numbers")
                            .with_type(ParameterType::number)
                            .build()?,
                        true,
                    )
                    .build()?,
            )
            .build()?])
        // Honour USE_STREAM so both call paths can be exercised.
        .with_stream(use_stream)
        .build()?;

    let res = req.call(&client, None).await?;

    let rep = match res {
        ChatCompletionResult::Response(rep) => rep,
        ChatCompletionResult::Delta(mut rx) => {
            // Fold the streamed chunks into one response.
            let mut rep_total = ChatCompletionResponse::default();
            while let Some(res) = rx.recv().await {
                match res {
                    Ok(rep) => {
                        rep_total.merge_delta(rep);
                    }
                    Err(e) => {
                        error!("failed to recv rep: {:?}", e);
                        break;
                    }
                }
            }
            rep_total
        }
    };

    for l in serde_json::to_string_pretty(&rep)?.lines() {
        info!("FINAL REP: {}", l);
    }

    Ok(())
}