Skip to main content

openai_models/
llm.rs

1use std::{
2    fmt::{Debug, Display},
3    ops::{Deref, DerefMut},
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10    time::Duration,
11};
12
13use async_openai::{
14    Client,
15    config::{AzureConfig, OpenAIConfig},
16    error::OpenAIError,
17    types::chat::{
18        ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19        ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20        ChatCompletionRequestAssistantMessageContentPart,
21        ChatCompletionRequestDeveloperMessageContent,
22        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27        ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28        ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29        CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30        CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31        ToolChoiceOptions,
32    },
33};
34use clap::Args;
35use color_eyre::{
36    Result,
37    eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
/// Accumulator for one streamed tool call: deltas arrive piecewise
/// (id / name usually once, arguments appended chunk by chunk) and are
/// merged here before the final response is assembled.
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    // Tool-call id; taken from the first delta that carries one.
    id: String,
    // Function name; taken from the first delta that carries one.
    name: String,
    // JSON argument string, concatenated across argument deltas.
    arguments: String,
}
53
// Upstream implementation is flawed
/// Newtype around `ChatCompletionToolChoiceOption` so a tool choice can be
/// parsed from a CLI flag / environment string via `FromStr`.
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59    type Err = PromptError;
60    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61        Ok(match s {
62            "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63                ToolChoiceOptions::Auto,
64            )),
65            "required" => Self(ChatCompletionToolChoiceOption::Mode(
66                ToolChoiceOptions::Required,
67            )),
68            "none" => Self(ChatCompletionToolChoiceOption::Mode(
69                ToolChoiceOptions::None,
70            )),
71            _ => Self(ChatCompletionToolChoiceOption::Custom(
72                ChatCompletionNamedToolChoiceCustom {
73                    custom: CustomName {
74                        name: s.to_string(),
75                    },
76                },
77            )),
78        })
79    }
80}
81
/// Read-through access to the wrapped `ChatCompletionToolChoiceOption`.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
/// Mutable access to the wrapped `ChatCompletionToolChoiceOption`.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
/// Wraps an upstream tool-choice value in the newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
100
/// Unwraps the newtype back into the upstream tool-choice value
/// (used when handing the choice to the request builder).
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
106
107#[derive(Args, Clone, Debug)]
108pub struct LLMSettings {
109    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
110    pub llm_temperature: f32,
111
112    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
113    pub llm_presence_penalty: f32,
114
115    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
116    pub llm_prompt_timeout: u64,
117
118    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
119    pub llm_retry: u64,
120
121    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
122    pub llm_max_completion_tokens: u32,
123
124    #[arg(long, env = "LLM_TOOL_CHOINCE")]
125    pub llm_tool_choice: Option<LLMToolChoice>,
126
127    #[arg(
128        long,
129        env = "LLM_STREAM",
130        default_value_t = false,
131        value_parser = clap::builder::BoolishValueParser::new()
132    )]
133    pub llm_stream: bool,
134}
135
// Connection / account configuration for either the OpenAI API or an Azure
// OpenAI deployment. Plain `//` comments are used on purpose: `///` doc
// comments would change clap's --help output.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL for the plain OpenAI-compatible API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // When set, the Azure backend is used instead of `openai_url`
    // (see `OpenAISetup::to_config`).
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    // API key; an empty key is used when absent.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // Azure deployment id; falls back to the model name when unset.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // Hard spending cap fed into `ModelBilling`.
    // NOTE(review): field name has a typo ("biling"); renaming this pub
    // field would break external callers, so it is left as-is.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // Directory for request/response transcripts; debugging disabled if unset.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171    pub fn to_config(&self) -> SupportedConfig {
172        if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173            let cfg = AzureConfig::new()
174                .with_api_base(ep)
175                .with_api_key(self.openai_key.clone().unwrap_or_default())
176                .with_deployment_id(
177                    self.azure_deployment
178                        .as_ref()
179                        .unwrap_or(&self.model.to_string()),
180                )
181                .with_api_version(&self.azure_api_version);
182            SupportedConfig::Azure(cfg)
183        } else {
184            let cfg = OpenAIConfig::new()
185                .with_api_base(&self.openai_url)
186                .with_api_key(self.openai_key.clone().unwrap_or_default());
187            SupportedConfig::OpenAI(cfg)
188        }
189    }
190
191    pub fn to_llm(&self) -> LLM {
192        let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195            let pid = std::process::id();
196
197            let mut cnt = 0u64;
198            let debug_path;
199            loop {
200                let test_path = dbg.join(format!("{}-{}", pid, cnt));
201                if !test_path.exists() {
202                    std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203                    debug_path = Some(test_path);
204                    debug!("The path to save LLM interactions is {:?}", &debug_path);
205                    break;
206                } else {
207                    cnt += 1;
208                }
209            }
210            debug_path
211        } else {
212            None
213        };
214
215        LLM {
216            llm: Arc::new(LLMInner {
217                client: LLMClient::new(self.to_config()),
218                model: self.model.clone(),
219                billing,
220                llm_debug: debug_path,
221                llm_debug_index: AtomicU64::new(0),
222                default_settings: self.llm_settings.clone(),
223            }),
224        }
225    }
226}
227
/// Backend configuration: Azure OpenAI or a plain OpenAI-compatible API.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// Concrete `async_openai` client for whichever backend was configured.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
impl LLMClient {
    /// Wraps the matching `async_openai` client for the given config.
    pub fn new(config: SupportedConfig) -> Self {
        match config {
            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
        }
    }

    /// Non-streaming chat completion, dispatched to the active backend.
    pub async fn create_chat(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create(req).await,
            Self::OpenAI(cl) => cl.chat().create(req).await,
        }
    }

    /// Streaming chat completion, dispatched to the active backend.
    pub async fn create_chat_stream(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create_stream(req).await,
            Self::OpenAI(cl) => cl.chat().create_stream(req).await,
        }
    }
}
268
/// Running cost tracker with a hard spending cap.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    /// Cost accumulated so far.
    pub current: f64,
    /// Maximum allowed cost; exceeding it makes further charges fail.
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278    }
279}
280
281impl ModelBilling {
282    pub fn new(cap: f64) -> Self {
283        Self { current: 0.0, cap }
284    }
285
286    pub fn in_cap(&self) -> bool {
287        self.current <= self.cap
288    }
289
290    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
291        let pricing = model.pricing();
292
293        self.current += (pricing.input_tokens * (count as f64)) / 1e6;
294
295        if self.in_cap() {
296            Ok(())
297        } else {
298            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
299        }
300    }
301
302    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
303        let pricing = model.pricing();
304
305        self.current += pricing.output_tokens * (count as f64) / 1e6;
306
307        if self.in_cap() {
308            Ok(())
309        } else {
310            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
311        }
312    }
313}
314
/// Cheaply-cloneable handle to the shared LLM state; cloning only bumps
/// the `Arc` refcount.
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
319
/// Lets callers invoke `LLMInner` methods directly on an `LLM` handle.
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
327
/// Shared state behind an `LLM` handle: client, billing, and debug config.
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    /// Cost accounting, guarded for concurrent prompt tasks.
    pub billing: RwLock<ModelBilling>,
    /// Directory for request/response transcripts; `None` disables dumping.
    pub llm_debug: Option<PathBuf>,
    /// Monotonic counter used to name transcript files.
    pub llm_debug_index: AtomicU64,
    /// Settings used when a call site passes `settings: None`.
    pub default_settings: LLMSettings,
}
337
/// Maps a request message variant to the uppercase role tag used in the
/// XML-ish debug transcript (`<ASSISTANT>…`, `<USER>…`, …).
pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
    match msg {
        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
        ChatCompletionRequestMessage::System(_) => "SYSTEM",
        ChatCompletionRequestMessage::Tool(_) => "TOOL",
        ChatCompletionRequestMessage::User(_) => "USER",
    }
}
348
349pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
350    match t {
351        ChatCompletionMessageToolCalls::Function(t) => {
352            format!(
353                "<toolcall name=\"{}\">\n{}\n</toolcall>",
354                &t.function.name, &t.function.arguments
355            )
356        }
357        ChatCompletionMessageToolCalls::Custom(t) => {
358            format!(
359                "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
360                &t.custom_tool.name, &t.custom_tool.input
361            )
362        }
363    }
364}
365
366pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
367    let mut s = String::new();
368    if let Some(content) = resp.content.as_ref() {
369        s += content;
370        s += "\n";
371    }
372
373    if let Some(tools) = resp.tool_calls.as_ref() {
374        s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
375    }
376
377    if let Some(refusal) = &resp.refusal {
378        s += refusal;
379        s += "\n";
380    }
381
382    let role = resp.role.to_string().to_uppercase();
383
384    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
385}
386
387pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
388    const CONT: &str = "<cont/>\n";
389    const NONE: &str = "<none/>\n";
390    let role = completion_to_role(msg);
391    let content = match msg {
392        ChatCompletionRequestMessage::Assistant(ass) => {
393            let msg = ass
394                .content
395                .as_ref()
396                .map(|ass| match ass {
397                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
398                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
399                        .iter()
400                        .map(|v| match v {
401                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
402                                s.text.clone()
403                            }
404                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
405                                rf.refusal.clone()
406                            }
407                        })
408                        .join(CONT),
409                })
410                .unwrap_or(NONE.to_string());
411            let tool_calls = ass
412                .tool_calls
413                .iter()
414                .flatten()
415                .map(|t| toolcall_to_string(t))
416                .join("\n");
417            format!("{}\n{}", msg, tool_calls)
418        }
419        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
420            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
421            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
422                .iter()
423                .map(|v| match v {
424                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
425                })
426                .join(CONT),
427        },
428        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
429        ChatCompletionRequestMessage::System(sys) => match &sys.content {
430            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
431            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
432                .iter()
433                .map(|v| match v {
434                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
435                })
436                .join(CONT),
437        },
438        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
439            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
440            ChatCompletionRequestToolMessageContent::Array(arr) => arr
441                .iter()
442                .map(|v| match v {
443                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
444                })
445                .join(CONT),
446        },
447        ChatCompletionRequestMessage::User(usr) => match &usr.content {
448            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
449            ChatCompletionRequestUserMessageContent::Array(arr) => arr
450                .iter()
451                .map(|v| match v {
452                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
453                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
454                        format!("<img url=\"{}\"/>", &img.image_url.url)
455                    }
456                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
457                        format!("<audio>{}</audio>", audio.input_audio.data)
458                    }
459                    ChatCompletionRequestUserMessageContentPart::File(f) => {
460                        format!("<file>{:?}</file>", f)
461                    }
462                })
463                .join(CONT),
464        },
465    };
466
467    format!("<{}>\n{}\n</{}>\n", role, content, role)
468}
469
470impl LLMInner {
471    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
472        let mut json_fp = fpath.to_path_buf();
473        json_fp.set_file_name(format!(
474            "{}.json",
475            json_fp
476                .file_stem()
477                .ok_or_eyre(eyre!("no filename"))?
478                .to_str()
479                .ok_or_eyre(eyre!("non-utf fname"))?
480        ));
481
482        let mut fp = tokio::fs::OpenOptions::new()
483            .create(true)
484            .append(true)
485            .write(true)
486            .open(&json_fp)
487            .await?;
488        let s = match serde_json::to_string(&t) {
489            Ok(s) => s,
490            Err(_) => format!("{:?}", &t),
491        };
492        fp.write_all(s.as_bytes()).await?;
493        fp.write_all(b"\n").await?;
494        fp.flush().await?;
495
496        Ok(())
497    }
498
499    async fn save_llm_user(
500        fpath: &PathBuf,
501        user_msg: &CreateChatCompletionRequest,
502    ) -> Result<(), PromptError> {
503        let mut fp = tokio::fs::OpenOptions::new()
504            .create(true)
505            .truncate(true)
506            .write(true)
507            .open(&fpath)
508            .await?;
509        fp.write_all(b"=====================\n<Request>\n").await?;
510        for it in user_msg.messages.iter() {
511            let msg = completion_to_string(it);
512            fp.write_all(msg.as_bytes()).await?;
513        }
514
515        let mut tools = vec![];
516        for tool in user_msg
517            .tools
518            .as_ref()
519            .map(|t| t.iter())
520            .into_iter()
521            .flatten()
522        {
523            let s = match tool {
524                ChatCompletionTools::Function(tool) => {
525                    format!(
526                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
527                        &tool.function.name,
528                        &tool.function.description.clone().unwrap_or_default(),
529                        tool.function.strict.unwrap_or_default(),
530                        tool.function
531                            .parameters
532                            .as_ref()
533                            .map(serde_json::to_string_pretty)
534                            .transpose()?
535                            .unwrap_or_default()
536                    )
537                }
538                ChatCompletionTools::Custom(tool) => {
539                    format!(
540                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
541                        tool.custom.name, tool.custom.description
542                    )
543                }
544            };
545            tools.push(s);
546        }
547        fp.write_all(tools.join("\n").as_bytes()).await?;
548        fp.write_all(b"\n</Request>\n=====================\n")
549            .await?;
550        fp.flush().await?;
551
552        Self::rewrite_json(fpath, user_msg).await?;
553
554        Ok(())
555    }
556
557    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
558        let mut fp = tokio::fs::OpenOptions::new()
559            .create(false)
560            .append(true)
561            .write(true)
562            .open(&fpath)
563            .await?;
564        fp.write_all(b"=====================\n<Response>\n").await?;
565        for it in &resp.choices {
566            let msg = response_to_string(&it.message);
567            fp.write_all(msg.as_bytes()).await?;
568        }
569        fp.write_all(b"\n</Response>\n=====================\n")
570            .await?;
571        fp.flush().await?;
572
573        Self::rewrite_json(fpath, resp).await?;
574
575        Ok(())
576    }
577
578    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
579        if let Some(output_folder) = self.llm_debug.as_ref() {
580            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
581            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
582            Some(fpath)
583        } else {
584            None
585        }
586    }
587
588    // we use t/s to estimate a timeout to avoid infinite repeating
589    pub async fn prompt_once_with_retry(
590        &self,
591        sys_msg: &str,
592        user_msg: &str,
593        prefix: Option<&str>,
594        settings: Option<LLMSettings>,
595    ) -> Result<CreateChatCompletionResponse, PromptError> {
596        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
597        let sys = ChatCompletionRequestSystemMessageArgs::default()
598            .content(sys_msg)
599            .build()?;
600
601        let user = ChatCompletionRequestUserMessageArgs::default()
602            .content(user_msg)
603            .build()?;
604        let mut req = CreateChatCompletionRequestArgs::default();
605        req.messages(vec![sys.into(), user.into()])
606            .model(self.model.to_string())
607            .temperature(settings.llm_temperature)
608            .presence_penalty(settings.llm_presence_penalty)
609            .max_completion_tokens(settings.llm_max_completion_tokens);
610
611        if let Some(tc) = settings.llm_tool_choice {
612            req.tool_choice(tc);
613        }
614        let req = req.build()?;
615
616        let timeout = if settings.llm_prompt_timeout == 0 {
617            Duration::MAX
618        } else {
619            Duration::from_secs(settings.llm_prompt_timeout)
620        };
621
622        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
623            .await
624    }
625
626    pub async fn complete_once_with_retry(
627        &self,
628        req: &CreateChatCompletionRequest,
629        prefix: Option<&str>,
630        timeout: Option<Duration>,
631        retry: Option<u64>,
632    ) -> Result<CreateChatCompletionResponse, PromptError> {
633        let timeout = if let Some(timeout) = timeout {
634            timeout
635        } else {
636            Duration::MAX
637        };
638
639        let retry = if let Some(retry) = retry {
640            retry
641        } else {
642            u64::MAX
643        };
644
645        let mut last = None;
646        for idx in 0..retry {
647            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
648                Ok(r) => {
649                    last = Some(r);
650                }
651                Err(_) => {
652                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
653                    continue;
654                }
655            };
656
657            match last {
658                Some(Ok(r)) => return Ok(r),
659                Some(Err(ref e)) => {
660                    warn!(
661                        "Having an error {} during {} retry (timeout is {:?})",
662                        e, idx, timeout
663                    );
664                }
665                _ => {}
666            }
667        }
668
669        last.ok_or_eyre(eyre!("retry is zero?!"))
670            .map_err(PromptError::Other)?
671    }
672
673    pub async fn complete(
674        &self,
675        req: CreateChatCompletionRequest,
676        prefix: Option<&str>,
677    ) -> Result<CreateChatCompletionResponse, PromptError> {
678        let use_stream = self.default_settings.llm_stream;
679        let prefix = if let Some(prefix) = prefix {
680            prefix.to_string()
681        } else {
682            "llm".to_string()
683        };
684        let debug_fp = self.on_llm_debug(&prefix);
685
686        if let Some(debug_fp) = debug_fp.as_ref() {
687            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
688                warn!("Fail to save user due to {}", e);
689            }
690        }
691
692        trace!(
693            "Sending completion request: {:?}",
694            &serde_json::to_string(&req)
695        );
696        let resp = if use_stream {
697            self.complete_streaming(req).await?
698        } else {
699            self.client.create_chat(req).await?
700        };
701
702        if let Some(debug_fp) = debug_fp.as_ref() {
703            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
704                warn!("Fail to save resp due to {}", e);
705            }
706        }
707
708        if let Some(usage) = &resp.usage {
709            self.billing
710                .write()
711                .await
712                .input_tokens(&self.model, usage.prompt_tokens as u64)
713                .map_err(PromptError::Other)?;
714            self.billing
715                .write()
716                .await
717                .output_tokens(&self.model, usage.completion_tokens as u64)
718                .map_err(PromptError::Other)?;
719        } else {
720            warn!("No usage?!")
721        }
722
723        info!("Model Billing: {}", &self.billing.read().await);
724        Ok(resp)
725    }
726
    /// Consumes the streaming API and reassembles the chunks into a single
    /// non-streaming `CreateChatCompletionResponse`, so callers never have
    /// to care whether streaming was enabled.
    async fn complete_streaming(
        &self,
        mut req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Ask the server to report token usage in the final chunk, unless
        // the caller already configured stream options.
        if req.stream_options.is_none() {
            req.stream_options = Some(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        }

        let mut stream = self.client.create_chat_stream(req).await?;

        // Response-level metadata, taken from the chunks as they arrive.
        let mut id: Option<String> = None;
        let mut created: Option<u32> = None;
        let mut model: Option<String> = None;
        let mut service_tier = None;
        let mut system_fingerprint = None;
        let mut usage: Option<CompletionUsage> = None;

        // Per-choice accumulators, indexed by the chunk's choice index.
        let mut contents: Vec<String> = Vec::new();
        let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
        let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();

        while let Some(item) = stream.next().await {
            let chunk: CreateChatCompletionStreamResponse = item?;
            // id: first chunk wins; the other metadata: last chunk wins.
            if id.is_none() {
                id = Some(chunk.id.clone());
            }
            created = Some(chunk.created);
            model = Some(chunk.model.clone());
            service_tier = chunk.service_tier.clone();
            system_fingerprint = chunk.system_fingerprint.clone();
            if let Some(u) = chunk.usage.clone() {
                usage = Some(u);
            }

            for ch in chunk.choices.into_iter() {
                let idx = ch.index as usize;
                // Grow all three accumulators in lockstep so indexing stays
                // in bounds for any choice index the server sends.
                if contents.len() <= idx {
                    contents.resize_with(idx + 1, String::new);
                    finish_reasons.resize_with(idx + 1, || None);
                    tool_calls.resize_with(idx + 1, Vec::new);
                }
                if let Some(delta) = ch.delta.content {
                    contents[idx].push_str(&delta);
                }
                if let Some(tcs) = ch.delta.tool_calls {
                    for tc in tcs.into_iter() {
                        let tc_idx = tc.index as usize;
                        if tool_calls[idx].len() <= tc_idx {
                            tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
                        }
                        let acc = &mut tool_calls[idx][tc_idx];
                        // id / name arrive once; arguments arrive piecewise
                        // and must be concatenated.
                        if let Some(id) = tc.id {
                            acc.id = id;
                        }
                        if let Some(func) = tc.function {
                            if let Some(name) = func.name {
                                acc.name = name;
                            }
                            if let Some(args) = func.arguments {
                                acc.arguments.push_str(&args);
                            }
                        }
                    }
                }
                if ch.finish_reason.is_some() {
                    finish_reasons[idx] = ch.finish_reason;
                }
            }
        }

        // Convert the accumulators into ordinary response choices.
        let mut choices = Vec::new();
        for (idx, content) in contents.into_iter().enumerate() {
            let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
            let built_tool_calls = tool_calls
                .get(idx)
                .cloned()
                .unwrap_or_default()
                .into_iter()
                // Skip slots that never received a name or arguments.
                .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
                .map(|t| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        // Synthesize an id if the server never provided one.
                        id: if t.id.trim().is_empty() {
                            format!("toolcall-{}", idx)
                        } else {
                            t.id
                        },
                        function: FunctionCall {
                            name: t.name,
                            arguments: t.arguments,
                        },
                    })
                })
                .collect::<Vec<_>>();
            let tool_calls_opt = if built_tool_calls.is_empty() {
                None
            } else {
                Some(built_tool_calls)
            };
            choices.push(ChatChoice {
                index: idx as u32,
                message: ChatCompletionResponseMessage {
                    content: if content.is_empty() {
                        None
                    } else {
                        Some(content)
                    },
                    refusal: None,
                    tool_calls: tool_calls_opt,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason,
                logprobs: None,
            });
        }
        // Guarantee at least one (empty) choice even if the stream produced
        // no content at all, so downstream indexing never panics.
        if choices.is_empty() {
            choices.push(ChatChoice {
                index: 0,
                message: ChatCompletionResponseMessage {
                    content: Some(String::new()),
                    refusal: None,
                    tool_calls: None,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason: None,
                logprobs: None,
            });
        }

        Ok(CreateChatCompletionResponse {
            id: id.unwrap_or_else(|| "stream".to_string()),
            choices,
            created: created.unwrap_or(0),
            model: model.unwrap_or_else(|| self.model.to_string()),
            service_tier,
            system_fingerprint,
            object: "chat.completion".to_string(),
            usage,
        })
    }
875
876    pub async fn prompt_once(
877        &self,
878        sys_msg: &str,
879        user_msg: &str,
880        prefix: Option<&str>,
881        settings: Option<LLMSettings>,
882    ) -> Result<CreateChatCompletionResponse, PromptError> {
883        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
884        let sys = ChatCompletionRequestSystemMessageArgs::default()
885            .content(sys_msg)
886            .build()?;
887
888        let user = ChatCompletionRequestUserMessageArgs::default()
889            .content(user_msg)
890            .build()?;
891        let req = CreateChatCompletionRequestArgs::default()
892            .messages(vec![sys.into(), user.into()])
893            .model(self.model.to_string())
894            .temperature(settings.llm_temperature)
895            .presence_penalty(settings.llm_presence_penalty)
896            .max_completion_tokens(settings.llm_max_completion_tokens)
897            .build()?;
898        self.complete(req, prefix).await
899    }
900}