Skip to main content

openai_models/
llm.rs

1use std::{
2    fmt::{Debug, Display},
3    ops::{Deref, DerefMut},
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10    time::Duration,
11};
12
13use async_openai::{
14    Client,
15    config::{AzureConfig, OpenAIConfig},
16    error::OpenAIError,
17    types::chat::{
18        ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19        ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20        ChatCompletionRequestAssistantMessageContentPart,
21        ChatCompletionRequestDeveloperMessageContent,
22        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27        ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28        ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29        CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30        CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31        ToolChoiceOptions,
32    },
33};
34use clap::Args;
35use color_eyre::{
36    Result,
37    eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
/// Accumulator for one tool call being reassembled from streaming deltas:
/// the id, function name, and argument JSON arrive in fragments and are
/// merged here until the stream ends (see `complete_streaming`).
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    // Tool-call id; overwritten whenever a delta carries one.
    id: String,
    // Function name; overwritten whenever a delta carries one.
    name: String,
    // Argument string; fragments are appended in arrival order.
    arguments: String,
}
53
// Upstream implementation is flawed, so we wrap the option in a newtype and
// provide our own `FromStr` (plus `Deref`/`From` below) so it can be used as
// a clap/env argument (see `LLMSettings::llm_tool_choice`).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59    type Err = PromptError;
60    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61        Ok(match s {
62            "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63                ToolChoiceOptions::Auto,
64            )),
65            "required" => Self(ChatCompletionToolChoiceOption::Mode(
66                ToolChoiceOptions::Required,
67            )),
68            "none" => Self(ChatCompletionToolChoiceOption::Mode(
69                ToolChoiceOptions::None,
70            )),
71            _ => Self(ChatCompletionToolChoiceOption::Custom(
72                ChatCompletionNamedToolChoiceCustom {
73                    custom: CustomName {
74                        name: s.to_string(),
75                    },
76                },
77            )),
78        })
79    }
80}
81
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    /// Transparent read access to the wrapped option.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
impl DerefMut for LLMToolChoice {
    /// Transparent mutable access to the wrapped option.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
/// Wrap an upstream option into the newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
100
/// Unwrap back to the upstream option (used when building requests).
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
106
107#[derive(Args, Clone, Debug)]
108pub struct LLMSettings {
109    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
110    pub llm_temperature: f32,
111
112    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
113    pub llm_presence_penalty: f32,
114
115    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
116    pub llm_prompt_timeout: u64,
117
118    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
119    pub llm_retry: u64,
120
121    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
122    pub llm_max_completion_tokens: u32,
123
124    #[arg(long, env = "LLM_TOOL_CHOINCE")]
125    pub llm_tool_choice: Option<LLMToolChoice>,
126
127    #[arg(
128        long,
129        env = "LLM_STREAM",
130        default_value_t = false,
131        value_parser = clap::builder::BoolishValueParser::new()
132    )]
133    pub llm_stream: bool,
134}
135
/// Connection/account configuration for either plain OpenAI or Azure
/// OpenAI, plus the default generation settings.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL for the plain-OpenAI client (ignored when an Azure endpoint
    // is set; see `to_config`).
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // When present, Azure mode is selected in `to_config`.
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    // API key for either backend; empty string is sent when unset.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // Azure deployment id; defaults to the model name when unset.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // Spend cap (same unit as ModelBilling, presumably USD — TODO confirm).
    // NOTE(review): name is misspelled ("biling"); renaming the pub field
    // would break the `--biling-cap` flag and callers, so it is kept as-is.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // Optional root directory for LLM interaction transcripts; a fresh
    // per-process subdirectory is created in `to_llm`.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171    pub fn to_config(&self) -> SupportedConfig {
172        if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173            let cfg = AzureConfig::new()
174                .with_api_base(ep)
175                .with_api_key(self.openai_key.clone().unwrap_or_default())
176                .with_deployment_id(
177                    self.azure_deployment
178                        .as_ref()
179                        .unwrap_or(&self.model.to_string()),
180                )
181                .with_api_version(&self.azure_api_version);
182            SupportedConfig::Azure(cfg)
183        } else {
184            let cfg = OpenAIConfig::new()
185                .with_api_base(&self.openai_url)
186                .with_api_key(self.openai_key.clone().unwrap_or_default());
187            SupportedConfig::OpenAI(cfg)
188        }
189    }
190
191    pub fn to_llm(&self) -> LLM {
192        let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195            let pid = std::process::id();
196
197            let mut cnt = 0u64;
198            let debug_path;
199            loop {
200                let test_path = dbg.join(format!("{}-{}", pid, cnt));
201                if !test_path.exists() {
202                    std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203                    debug_path = Some(test_path);
204                    debug!("The path to save LLM interactions is {:?}", &debug_path);
205                    break;
206                } else {
207                    cnt += 1;
208                }
209            }
210            debug_path
211        } else {
212            None
213        };
214
215        LLM {
216            llm: Arc::new(LLMInner {
217                client: LLMClient::new(self.to_config()),
218                model: self.model.clone(),
219                billing,
220                llm_debug: debug_path,
221                llm_debug_index: AtomicU64::new(0),
222                default_settings: self.llm_settings.clone(),
223            }),
224        }
225    }
226}
227
/// Client configuration for one of the two supported backends.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// Concrete `async_openai` client, one variant per supported backend.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
240impl LLMClient {
241    pub fn new(config: SupportedConfig) -> Self {
242        match config {
243            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
244            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
245        }
246    }
247
248    pub async fn create_chat(
249        &self,
250        req: CreateChatCompletionRequest,
251    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
252        match self {
253            Self::Azure(cl) => cl.chat().create(req).await,
254            Self::OpenAI(cl) => cl.chat().create(req).await,
255        }
256    }
257
258    pub async fn create_chat_stream(
259        &self,
260        req: CreateChatCompletionRequest,
261    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
262        match self {
263            Self::Azure(cl) => cl.chat().create_stream(req).await,
264            Self::OpenAI(cl) => cl.chat().create_stream(req).await,
265        }
266    }
267}
268
/// Running spend tracker: `current` accumulates cost and requests fail once
/// it exceeds `cap` (units follow the model pricing tables, per 1e6 tokens).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278    }
279}
280
281impl ModelBilling {
282    pub fn new(cap: f64) -> Self {
283        Self { current: 0.0, cap }
284    }
285
286    pub fn in_cap(&self) -> bool {
287        self.current <= self.cap
288    }
289
290    pub fn input_tokens(
291        &mut self,
292        model: &OpenAIModel,
293        count: u64,
294        cached_count: u64,
295    ) -> Result<()> {
296        let pricing = model.pricing();
297
298        let cached_price = if let Some(cached) = pricing.cached_input_tokens {
299            cached
300        } else {
301            pricing.input_tokens
302        };
303
304        self.current += (cached_price * (cached_count as f64)) / 1e6;
305        self.current += (pricing.input_tokens * (count as f64)) / 1e6;
306
307        if self.in_cap() {
308            Ok(())
309        } else {
310            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
311        }
312    }
313
314    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
315        let pricing = model.pricing();
316
317        self.current += pricing.output_tokens * (count as f64) / 1e6;
318
319        if self.in_cap() {
320            Ok(())
321        } else {
322            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
323        }
324    }
325}
326
/// Cheaply-cloneable handle to the shared LLM state (`Arc` inside).
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
331
impl Deref for LLM {
    type Target = LLMInner;

    /// Lets callers invoke `LLMInner` methods directly on the handle.
    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
339
/// Shared state behind the `LLM` handle.
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    // Spend tracker; charged after every completed request.
    pub billing: RwLock<ModelBilling>,
    // Directory for request/response transcripts; None disables dumping.
    pub llm_debug: Option<PathBuf>,
    // Monotonic counter used to name transcript files.
    pub llm_debug_index: AtomicU64,
    // Defaults applied when a call passes `settings: None`.
    pub default_settings: LLMSettings,
}
349
350pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
351    match msg {
352        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
353        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
354        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
355        ChatCompletionRequestMessage::System(_) => "SYSTEM",
356        ChatCompletionRequestMessage::Tool(_) => "TOOL",
357        ChatCompletionRequestMessage::User(_) => "USER",
358    }
359}
360
361pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
362    match t {
363        ChatCompletionMessageToolCalls::Function(t) => {
364            format!(
365                "<toolcall name=\"{}\">\n{}\n</toolcall>",
366                &t.function.name, &t.function.arguments
367            )
368        }
369        ChatCompletionMessageToolCalls::Custom(t) => {
370            format!(
371                "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
372                &t.custom_tool.name, &t.custom_tool.input
373            )
374        }
375    }
376}
377
378pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
379    let mut s = String::new();
380    if let Some(content) = resp.content.as_ref() {
381        s += content;
382        s += "\n";
383    }
384
385    if let Some(tools) = resp.tool_calls.as_ref() {
386        s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
387    }
388
389    if let Some(refusal) = &resp.refusal {
390        s += refusal;
391        s += "\n";
392    }
393
394    let role = resp.role.to_string().to_uppercase();
395
396    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
397}
398
/// Renders a request message as an XML-ish transcript entry, wrapped in the
/// role tag from `completion_to_role`. Multi-part contents are joined with
/// a `<cont/>` separator; absent contents render as `<none/>`.
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    // Separator between parts of an array-style content.
    const CONT: &str = "<cont/>\n";
    // Placeholder for a missing content field.
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            // Assistant messages may carry text/refusal parts plus tool calls.
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| toolcall_to_string(t))
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        // User parts may be text, image URLs, audio (base64 data inlined),
        // or file attachments (Debug-formatted).
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                    ChatCompletionRequestUserMessageContentPart::File(f) => {
                        format!("<file>{:?}</file>", f)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}
481
impl LLMInner {
    /// Appends `t` as one JSON line to a sibling of `fpath` with the same
    /// file stem and a `.json` extension. Falls back to the `Debug`
    /// representation when JSON serialization fails, so the dump never
    /// aborts a request.
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf fname"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        // Prefer JSON; degrade to Debug output rather than failing the save.
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

    /// Writes the outgoing request (messages plus tool definitions) as the
    /// `<Request>` section of the transcript file at `fpath`, creating or
    /// truncating it, then mirrors the raw request to the `.json` sibling.
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        // Render each declared tool (function or custom) after the messages.
        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t| t.iter())
            .into_iter()
            .flatten()
        {
            let s = match tool {
                ChatCompletionTools::Function(tool) => {
                    format!(
                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                        &tool.function.name,
                        &tool.function.description.clone().unwrap_or_default(),
                        tool.function.strict.unwrap_or_default(),
                        tool.function
                            .parameters
                            .as_ref()
                            .map(serde_json::to_string_pretty)
                            .transpose()?
                            .unwrap_or_default()
                    )
                }
                ChatCompletionTools::Custom(tool) => {
                    format!(
                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
                        tool.custom.name, tool.custom.description
                    )
                }
            };
            tools.push(s);
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

    /// Appends the `<Response>` section to the transcript file previously
    /// created by `save_llm_user` (`create(false)`: the request must have
    /// been written first), then mirrors the raw response to `.json`.
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes)().await?;
        }
        fp.write_all(b"\n</Response>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

    /// Reserves the next transcript path `<prefix>-<zero-padded idx>.xml`
    /// under the debug directory; returns None when debugging is disabled.
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            // SeqCst fetch_add keeps file names unique across tasks.
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

    // we use t/s to estimate a timeout to avoid infinite repeating
    /// Builds a system+user chat request from the given settings (falling
    /// back to the defaults), applies the optional tool choice, and runs it
    /// through `complete_once_with_retry` with the configured timeout and
    /// retry budget. A timeout setting of 0 means "wait forever".
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let mut req = CreateChatCompletionRequestArgs::default();
        req.messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens);

        if let Some(tc) = settings.llm_tool_choice {
            req.tool_choice(tc);
        }
        let req = req.build()?;

        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

    /// Runs `complete` up to `retry` times (None = effectively unlimited),
    /// each attempt bounded by `timeout` (None = no bound). Returns the
    /// first success; after exhausting retries, returns the last error seen.
    /// Errors with "retry is zero?!" when no attempt ever produced a result.
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    // Attempt timed out; keep the previous outcome in `last`
                    // and try again.
                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                // Success: return immediately (moves the response out).
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Having an error {} during {} retry (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        // All attempts failed: propagate the last error, or complain if the
        // loop never ran (retry == 0) / every attempt timed out.
        last.ok_or_eyre(eyre!("retry is zero?!"))
            .map_err(PromptError::Other)?
    }

    /// Single completion round-trip: optionally dumps the request to the
    /// transcript, dispatches via streaming or blocking API (per the
    /// default `llm_stream` setting), dumps the response, and charges the
    /// billing tracker from the reported usage (failing once the cap is
    /// exceeded).
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let use_stream = self.default_settings.llm_stream;
        let prefix = if let Some(prefix) = prefix {
            prefix.to_string()
        } else {
            "llm".to_string()
        };
        let debug_fp = self.on_llm_debug(&prefix);

        // Transcript failures are logged, never fatal.
        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Fail to save user due to {}", e);
            }
        }

        trace!(
            "Sending completion request: {:?}",
            &serde_json::to_string(&req)
        );
        let resp = if use_stream {
            self.complete_streaming(req).await?
        } else {
            self.client.create_chat(req).await?
        };

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Fail to save resp due to {}", e);
            }
        }

        if let Some(usage) = &resp.usage {
            let cached = usage
                .prompt_tokens_details
                .as_ref()
                .map(|v| v.cached_tokens)
                .flatten()
                .unwrap_or_default();
            // NOTE(review): assumes cached <= prompt_tokens; would underflow
            // otherwise — confirm against the API contract.
            let input = usage.prompt_tokens - cached;
            self.billing
                .write()
                .await
                .input_tokens(&self.model, input as _, cached as _)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage?!")
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

    /// Drives a streaming completion to the end and reassembles the chunks
    /// into a synthetic non-streaming `CreateChatCompletionResponse`:
    /// content and tool-call deltas are concatenated per choice index, and
    /// usage is taken from whichever chunk carries it (requested via
    /// `include_usage`).
    async fn complete_streaming(
        &self,
        mut req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Ask the server to report usage in the final chunk unless the
        // caller already configured stream options.
        if req.stream_options.is_none() {
            req.stream_options = Some(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        }

        let mut stream = self.client.create_chat_stream(req).await?;

        // Response-level metadata, captured from the chunks as they arrive
        // (id from the first chunk; the rest keep the latest value).
        let mut id: Option<String> = None;
        let mut created: Option<u32> = None;
        let mut model: Option<String> = None;
        let mut service_tier = None;
        let mut system_fingerprint = None;
        let mut usage: Option<CompletionUsage> = None;

        // Per-choice accumulators, indexed by the chunk's choice index.
        let mut contents: Vec<String> = Vec::new();
        let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
        let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();

        while let Some(item) = stream.next().await {
            let chunk: CreateChatCompletionStreamResponse = item?;
            if id.is_none() {
                id = Some(chunk.id.clone());
            }
            created = Some(chunk.created);
            model = Some(chunk.model.clone());
            service_tier = chunk.service_tier.clone();
            system_fingerprint = chunk.system_fingerprint.clone();
            if let Some(u) = chunk.usage.clone() {
                usage = Some(u);
            }

            for ch in chunk.choices.into_iter() {
                let idx = ch.index as usize;
                // Grow all three accumulators in lock-step so they stay
                // index-aligned.
                if contents.len() <= idx {
                    contents.resize_with(idx + 1, String::new);
                    finish_reasons.resize_with(idx + 1, || None);
                    tool_calls.resize_with(idx + 1, Vec::new);
                }
                if let Some(delta) = ch.delta.content {
                    contents[idx].push_str(&delta);
                }
                if let Some(tcs) = ch.delta.tool_calls {
                    for tc in tcs.into_iter() {
                        let tc_idx = tc.index as usize;
                        if tool_calls[idx].len() <= tc_idx {
                            tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
                        }
                        let acc = &mut tool_calls[idx][tc_idx];
                        // id/name overwrite; arguments accumulate.
                        if let Some(id) = tc.id {
                            acc.id = id;
                        }
                        if let Some(func) = tc.function {
                            if let Some(name) = func.name {
                                acc.name = name;
                            }
                            if let Some(args) = func.arguments {
                                acc.arguments.push_str(&args);
                            }
                        }
                    }
                }
                if ch.finish_reason.is_some() {
                    finish_reasons[idx] = ch.finish_reason;
                }
            }
        }

        // Materialize the accumulated per-choice state into ChatChoice
        // values, dropping tool-call accumulators that never received a
        // name or arguments.
        let mut choices = Vec::new();
        for (idx, content) in contents.into_iter().enumerate() {
            let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
            let built_tool_calls = tool_calls
                .get(idx)
                .cloned()
                .unwrap_or_default()
                .into_iter()
                .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
                .map(|t| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        // Synthesize an id when the stream never sent one.
                        id: if t.id.trim().is_empty() {
                            format!("toolcall-{}", idx)
                        } else {
                            t.id
                        },
                        function: FunctionCall {
                            name: t.name,
                            arguments: t.arguments,
                        },
                    })
                })
                .collect::<Vec<_>>();
            let tool_calls_opt = if built_tool_calls.is_empty() {
                None
            } else {
                Some(built_tool_calls)
            };
            choices.push(ChatChoice {
                index: idx as u32,
                message: ChatCompletionResponseMessage {
                    content: if content.is_empty() {
                        None
                    } else {
                        Some(content)
                    },
                    refusal: None,
                    tool_calls: tool_calls_opt,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason,
                logprobs: None,
            });
        }
        // Guarantee at least one (empty) choice so downstream code can
        // index choices[0].
        if choices.is_empty() {
            choices.push(ChatChoice {
                index: 0,
                message: ChatCompletionResponseMessage {
                    content: Some(String::new()),
                    refusal: None,
                    tool_calls: None,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason: None,
                logprobs: None,
            });
        }

        Ok(CreateChatCompletionResponse {
            id: id.unwrap_or_else(|| "stream".to_string()),
            choices,
            created: created.unwrap_or(0),
            model: model.unwrap_or_else(|| self.model.to_string()),
            service_tier,
            system_fingerprint,
            object: "chat.completion".to_string(),
            usage,
        })
    }

    /// Single-shot system+user prompt without retry/timeout handling.
    /// Unlike `prompt_once_with_retry`, this does not apply the settings'
    /// tool choice, timeout, or retry budget — only the sampling parameters.
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}
919}