1use std::{
2 fmt::{Debug, Display},
3 ops::{Deref, DerefMut},
4 path::{Path, PathBuf},
5 str::FromStr,
6 sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9 },
10 time::Duration,
11};
12
13use async_openai::{
14 Client,
15 config::{AzureConfig, OpenAIConfig},
16 error::OpenAIError,
17 types::chat::{
18 ChatCompletionMessageToolCalls, ChatCompletionNamedToolChoiceCustom,
19 ChatCompletionRequestAssistantMessageContent,
20 ChatCompletionRequestAssistantMessageContentPart,
21 ChatCompletionRequestDeveloperMessageContent,
22 ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24 ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25 ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27 ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, ChatCompletionTools,
28 CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
29 CustomName, ToolChoiceOptions,
30 },
31};
32use clap::Args;
33use color_eyre::{
34 Result,
35 eyre::{OptionExt, eyre},
36};
37use itertools::Itertools;
38use log::{debug, info, trace, warn};
39use serde::{Deserialize, Serialize};
40use tokio::{io::AsyncWriteExt, sync::RwLock};
41
42use crate::{OpenAIModel, error::PromptError};
43
/// Newtype around [`ChatCompletionToolChoiceOption`] so a tool-choice
/// strategy can be parsed from a CLI flag or environment variable
/// (see the `FromStr` impl below).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
47
48impl FromStr for LLMToolChoice {
49 type Err = PromptError;
50 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
51 Ok(match s {
52 "auto" => Self(ChatCompletionToolChoiceOption::Mode(
53 ToolChoiceOptions::Auto,
54 )),
55 "required" => Self(ChatCompletionToolChoiceOption::Mode(
56 ToolChoiceOptions::Required,
57 )),
58 "none" => Self(ChatCompletionToolChoiceOption::Mode(
59 ToolChoiceOptions::None,
60 )),
61 _ => Self(ChatCompletionToolChoiceOption::Custom(
62 ChatCompletionNamedToolChoiceCustom {
63 custom: CustomName {
64 name: s.to_string(),
65 },
66 },
67 )),
68 })
69 }
70}
71
72impl Deref for LLMToolChoice {
73 type Target = ChatCompletionToolChoiceOption;
74 fn deref(&self) -> &Self::Target {
75 &self.0
76 }
77}
78
79impl DerefMut for LLMToolChoice {
80 fn deref_mut(&mut self) -> &mut Self::Target {
81 &mut self.0
82 }
83}
84
85impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
86 fn from(value: ChatCompletionToolChoiceOption) -> Self {
87 Self(value)
88 }
89}
90
91impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
92 fn from(value: LLMToolChoice) -> Self {
93 value.0
94 }
95}
96
97#[derive(Args, Clone, Debug)]
98pub struct LLMSettings {
99 #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
100 pub llm_temperature: f32,
101
102 #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
103 pub llm_presence_penalty: f32,
104
105 #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
106 pub llm_prompt_timeout: u64,
107
108 #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
109 pub llm_retry: u64,
110
111 #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
112 pub llm_max_completion_tokens: u32,
113
114 #[arg(long, env = "LLM_TOOL_CHOINCE")]
115 pub llm_tool_choice: Option<LLMToolChoice>,
116}
117
/// CLI/env configuration for connecting to OpenAI or Azure OpenAI.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    /// Base URL for the plain OpenAI-compatible API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    /// When set, the Azure backend is used instead of `openai_url`.
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    /// API key for either backend; an empty key is used when unset.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    /// Azure deployment id; falls back to the model name when unset.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    /// Azure REST API version string.
    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    /// Hard spend cap enforced by `ModelBilling`.
    // NOTE(review): "biling_cap" is a typo for "billing_cap", but renaming it
    // would change the derived `--biling-cap` CLI flag and break callers that
    // access the field — left as-is.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    /// Model to request completions from.
    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    /// Folder for LLM request/response transcripts; unset disables debugging.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    /// Per-request LLM parameters (flattened into this command line).
    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
151
152impl OpenAISetup {
153 pub fn to_config(&self) -> SupportedConfig {
154 if let Some(ep) = self.azure_openai_endpoint.as_ref() {
155 let cfg = AzureConfig::new()
156 .with_api_base(ep)
157 .with_api_key(self.openai_key.clone().unwrap_or_default())
158 .with_deployment_id(
159 self.azure_deployment
160 .as_ref()
161 .unwrap_or(&self.model.to_string()),
162 )
163 .with_api_version(&self.azure_api_version);
164 SupportedConfig::Azure(cfg)
165 } else {
166 let cfg = OpenAIConfig::new()
167 .with_api_base(&self.openai_url)
168 .with_api_key(self.openai_key.clone().unwrap_or_default());
169 SupportedConfig::OpenAI(cfg)
170 }
171 }
172
173 pub fn to_llm(&self) -> LLM {
174 let billing = RwLock::new(ModelBilling::new(self.biling_cap));
175
176 let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
177 let pid = std::process::id();
178
179 let mut cnt = 0u64;
180 let debug_path;
181 loop {
182 let test_path = dbg.join(format!("{}-{}", pid, cnt));
183 if !test_path.exists() {
184 std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
185 debug_path = Some(test_path);
186 debug!("The path to save LLM interactions is {:?}", &debug_path);
187 break;
188 } else {
189 cnt += 1;
190 }
191 }
192 debug_path
193 } else {
194 None
195 };
196
197 LLM {
198 llm: Arc::new(LLMInner {
199 client: LLMClient::new(self.to_config()),
200 model: self.model.clone(),
201 billing,
202 llm_debug: debug_path,
203 llm_debug_index: AtomicU64::new(0),
204 default_settings: self.llm_settings.clone(),
205 }),
206 }
207 }
208}
209
/// Client configuration for either supported backend.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    /// Azure OpenAI (endpoint, deployment id, API version).
    Azure(AzureConfig),
    /// Plain OpenAI-compatible API.
    OpenAI(OpenAIConfig),
}
215
/// A configured chat client for either supported backend.
#[derive(Debug, Clone)]
pub enum LLMClient {
    /// Client talking to Azure OpenAI.
    Azure(Client<AzureConfig>),
    /// Client talking to a plain OpenAI-compatible API.
    OpenAI(Client<OpenAIConfig>),
}
221
222impl LLMClient {
223 pub fn new(config: SupportedConfig) -> Self {
224 match config {
225 SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
226 SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
227 }
228 }
229
230 pub async fn create_chat(
231 &self,
232 req: CreateChatCompletionRequest,
233 ) -> Result<CreateChatCompletionResponse, OpenAIError> {
234 match self {
235 Self::Azure(cl) => cl.chat().create(req).await,
236 Self::OpenAI(cl) => cl.chat().create(req).await,
237 }
238 }
239}
240
/// Running spend tracker: accumulated cost (`current`) against a hard cap,
/// both in the unit used by `OpenAIModel::pricing` (per-1M-token prices).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    // Accumulated spend so far.
    pub current: f64,
    // Spend limit; exceeding it makes further charges fail.
    pub cap: f64,
}
246
247impl Display for ModelBilling {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
250 }
251}
252
253impl ModelBilling {
254 pub fn new(cap: f64) -> Self {
255 Self { current: 0.0, cap }
256 }
257
258 pub fn in_cap(&self) -> bool {
259 self.current <= self.cap
260 }
261
262 pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
263 let pricing = model.pricing();
264
265 self.current += (pricing.input_tokens * (count as f64)) / 1e6;
266
267 if self.in_cap() {
268 Ok(())
269 } else {
270 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
271 }
272 }
273
274 pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
275 let pricing = model.pricing();
276
277 self.current += pricing.output_tokens * (count as f64) / 1e6;
278
279 if self.in_cap() {
280 Ok(())
281 } else {
282 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
283 }
284 }
285}
286
/// Cheaply clonable handle to the shared LLM client state ([`LLMInner`]).
#[derive(Debug, Clone)]
pub struct LLM {
    // Shared inner state; `Deref` below forwards method calls to it.
    pub llm: Arc<LLMInner>,
}
291
292impl Deref for LLM {
293 type Target = LLMInner;
294
295 fn deref(&self) -> &Self::Target {
296 &self.llm
297 }
298}
299
/// Shared state behind [`LLM`]: the API client, model, billing tracker and
/// optional debug-transcript configuration.
#[derive(Debug)]
pub struct LLMInner {
    // Backend client (Azure or plain OpenAI).
    pub client: LLMClient,
    // Model requested in every completion.
    pub model: OpenAIModel,
    // Running spend tracker, guarded for concurrent use.
    pub billing: RwLock<ModelBilling>,
    // Folder for request/response transcripts; `None` disables debugging.
    pub llm_debug: Option<PathBuf>,
    // Monotonic counter used to name transcript files.
    pub llm_debug_index: AtomicU64,
    // Settings applied when a call does not supply its own.
    pub default_settings: LLMSettings,
}
309
310pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
311 match msg {
312 ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
313 ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
314 ChatCompletionRequestMessage::Function(_) => "FUNCTION",
315 ChatCompletionRequestMessage::System(_) => "SYSTEM",
316 ChatCompletionRequestMessage::Tool(_) => "TOOL",
317 ChatCompletionRequestMessage::User(_) => "USER",
318 }
319}
320
321pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
322 match t {
323 ChatCompletionMessageToolCalls::Function(t) => {
324 format!(
325 "<toolcall name=\"{}\">\n{}\n</toolcall>",
326 &t.function.name, &t.function.arguments
327 )
328 }
329 ChatCompletionMessageToolCalls::Custom(t) => {
330 format!(
331 "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
332 &t.custom_tool.name, &t.custom_tool.input
333 )
334 }
335 }
336}
337
338pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
339 let mut s = String::new();
340 if let Some(content) = resp.content.as_ref() {
341 s += content;
342 s += "\n";
343 }
344
345 if let Some(tools) = resp.tool_calls.as_ref() {
346 s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
347 }
348
349 if let Some(refusal) = &resp.refusal {
350 s += refusal;
351 s += "\n";
352 }
353
354 let role = resp.role.to_string().to_uppercase();
355
356 format!("<{}>\n{}\n</{}>\n", &role, s, &role)
357}
358
/// Render a request message as an XML-ish `<ROLE>…</ROLE>` snippet for the
/// debug transcript. Array-valued contents are joined with a `<cont/>`
/// marker; a missing content is rendered as `<none/>`.
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    // Separator between the parts of an array-valued content.
    const CONT: &str = "<cont/>\n";
    // Placeholder when a message carries no content at all.
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            // Assistant content is optional; text and refusal parts are both
            // rendered as their plain text.
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            // Tool calls (if any) are appended after the textual content.
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| toolcall_to_string(t))
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        // User content can mix text, images, audio and files; non-text parts
        // are rendered as placeholder tags rather than raw payloads.
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                    ChatCompletionRequestUserMessageContentPart::File(f) => {
                        format!("<file>{:?}</file>", f)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}
441
442impl LLMInner {
443 async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
444 let mut json_fp = fpath.to_path_buf();
445 json_fp.set_file_name(format!(
446 "{}.json",
447 json_fp
448 .file_stem()
449 .ok_or_eyre(eyre!("no filename"))?
450 .to_str()
451 .ok_or_eyre(eyre!("non-utf fname"))?
452 ));
453
454 let mut fp = tokio::fs::OpenOptions::new()
455 .create(true)
456 .append(true)
457 .write(true)
458 .open(&json_fp)
459 .await?;
460 let s = match serde_json::to_string(&t) {
461 Ok(s) => s,
462 Err(_) => format!("{:?}", &t),
463 };
464 fp.write_all(s.as_bytes()).await?;
465 fp.write_all(b"\n").await?;
466 fp.flush().await?;
467
468 Ok(())
469 }
470
471 async fn save_llm_user(
472 fpath: &PathBuf,
473 user_msg: &CreateChatCompletionRequest,
474 ) -> Result<(), PromptError> {
475 let mut fp = tokio::fs::OpenOptions::new()
476 .create(true)
477 .truncate(true)
478 .write(true)
479 .open(&fpath)
480 .await?;
481 fp.write_all(b"=====================\n<Request>\n").await?;
482 for it in user_msg.messages.iter() {
483 let msg = completion_to_string(it);
484 fp.write_all(msg.as_bytes()).await?;
485 }
486
487 let mut tools = vec![];
488 for tool in user_msg
489 .tools
490 .as_ref()
491 .map(|t| t.iter())
492 .into_iter()
493 .flatten()
494 {
495 let s = match tool {
496 ChatCompletionTools::Function(tool) => {
497 format!(
498 "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
499 &tool.function.name,
500 &tool.function.description.clone().unwrap_or_default(),
501 tool.function.strict.unwrap_or_default(),
502 tool.function
503 .parameters
504 .as_ref()
505 .map(serde_json::to_string_pretty)
506 .transpose()?
507 .unwrap_or_default()
508 )
509 }
510 ChatCompletionTools::Custom(tool) => {
511 format!(
512 "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
513 tool.custom.name, tool.custom.description
514 )
515 }
516 };
517 tools.push(s);
518 }
519 fp.write_all(tools.join("\n").as_bytes()).await?;
520 fp.write_all(b"\n</Request>\n=====================\n")
521 .await?;
522 fp.flush().await?;
523
524 Self::rewrite_json(fpath, user_msg).await?;
525
526 Ok(())
527 }
528
529 async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
530 let mut fp = tokio::fs::OpenOptions::new()
531 .create(false)
532 .append(true)
533 .write(true)
534 .open(&fpath)
535 .await?;
536 fp.write_all(b"=====================\n<Response>\n").await?;
537 for it in &resp.choices {
538 let msg = response_to_string(&it.message);
539 fp.write_all(msg.as_bytes()).await?;
540 }
541 fp.write_all(b"\n</Response>\n=====================\n")
542 .await?;
543 fp.flush().await?;
544
545 Self::rewrite_json(fpath, resp).await?;
546
547 Ok(())
548 }
549
550 fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
551 if let Some(output_folder) = self.llm_debug.as_ref() {
552 let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
553 let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
554 Some(fpath)
555 } else {
556 None
557 }
558 }
559
560 pub async fn prompt_once_with_retry(
562 &self,
563 sys_msg: &str,
564 user_msg: &str,
565 prefix: Option<&str>,
566 settings: Option<LLMSettings>,
567 ) -> Result<CreateChatCompletionResponse, PromptError> {
568 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
569 let sys = ChatCompletionRequestSystemMessageArgs::default()
570 .content(sys_msg)
571 .build()?;
572
573 let user = ChatCompletionRequestUserMessageArgs::default()
574 .content(user_msg)
575 .build()?;
576 let mut req = CreateChatCompletionRequestArgs::default();
577 req.messages(vec![sys.into(), user.into()])
578 .model(self.model.to_string())
579 .temperature(settings.llm_temperature)
580 .presence_penalty(settings.llm_presence_penalty)
581 .max_completion_tokens(settings.llm_max_completion_tokens);
582
583 if let Some(tc) = settings.llm_tool_choice {
584 req.tool_choice(tc);
585 }
586 let req = req.build()?;
587
588 let timeout = if settings.llm_prompt_timeout == 0 {
589 Duration::MAX
590 } else {
591 Duration::from_secs(settings.llm_prompt_timeout)
592 };
593
594 self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
595 .await
596 }
597
598 pub async fn complete_once_with_retry(
599 &self,
600 req: &CreateChatCompletionRequest,
601 prefix: Option<&str>,
602 timeout: Option<Duration>,
603 retry: Option<u64>,
604 ) -> Result<CreateChatCompletionResponse, PromptError> {
605 let timeout = if let Some(timeout) = timeout {
606 timeout
607 } else {
608 Duration::MAX
609 };
610
611 let retry = if let Some(retry) = retry {
612 retry
613 } else {
614 u64::MAX
615 };
616
617 let mut last = None;
618 for idx in 0..retry {
619 match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
620 Ok(r) => {
621 last = Some(r);
622 }
623 Err(_) => {
624 warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
625 continue;
626 }
627 };
628
629 match last {
630 Some(Ok(r)) => return Ok(r),
631 Some(Err(ref e)) => {
632 warn!(
633 "Having an error {} during {} retry (timeout is {:?})",
634 e, idx, timeout
635 );
636 }
637 _ => {}
638 }
639 }
640
641 last.ok_or_eyre(eyre!("retry is zero?!"))
642 .map_err(PromptError::Other)?
643 }
644
645 pub async fn complete(
646 &self,
647 req: CreateChatCompletionRequest,
648 prefix: Option<&str>,
649 ) -> Result<CreateChatCompletionResponse, PromptError> {
650 let prefix = if let Some(prefix) = prefix {
651 prefix.to_string()
652 } else {
653 "llm".to_string()
654 };
655 let debug_fp = self.on_llm_debug(&prefix);
656
657 if let Some(debug_fp) = debug_fp.as_ref() {
658 if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
659 warn!("Fail to save user due to {}", e);
660 }
661 }
662
663 trace!(
664 "Sending completion request: {:?}",
665 &serde_json::to_string(&req)
666 );
667 let resp = self.client.create_chat(req).await?;
668
669 if let Some(debug_fp) = debug_fp.as_ref() {
670 if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
671 warn!("Fail to save resp due to {}", e);
672 }
673 }
674
675 if let Some(usage) = &resp.usage {
676 self.billing
677 .write()
678 .await
679 .input_tokens(&self.model, usage.prompt_tokens as u64)
680 .map_err(PromptError::Other)?;
681 self.billing
682 .write()
683 .await
684 .output_tokens(&self.model, usage.completion_tokens as u64)
685 .map_err(PromptError::Other)?;
686 } else {
687 warn!("No usage?!")
688 }
689
690 info!("Model Billing: {}", &self.billing.read().await);
691 Ok(resp)
692 }
693
694 pub async fn prompt_once(
695 &self,
696 sys_msg: &str,
697 user_msg: &str,
698 prefix: Option<&str>,
699 settings: Option<LLMSettings>,
700 ) -> Result<CreateChatCompletionResponse, PromptError> {
701 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
702 let sys = ChatCompletionRequestSystemMessageArgs::default()
703 .content(sys_msg)
704 .build()?;
705
706 let user = ChatCompletionRequestUserMessageArgs::default()
707 .content(user_msg)
708 .build()?;
709 let req = CreateChatCompletionRequestArgs::default()
710 .messages(vec![sys.into(), user.into()])
711 .model(self.model.to_string())
712 .temperature(settings.llm_temperature)
713 .presence_penalty(settings.llm_presence_penalty)
714 .max_completion_tokens(settings.llm_max_completion_tokens)
715 .build()?;
716 self.complete(req, prefix).await
717 }
718}