alith_prompt/
api_prompt.rs1use crate::{PromptTokenizer, token_count::total_prompt_tokens_openai_format};
2use serde::Serialize;
3use std::{
4 collections::HashMap,
5 sync::{Arc, Mutex, MutexGuard},
6};
7
8#[derive(Serialize)]
17pub struct ApiPrompt {
18 #[serde(skip)]
19 tokenizer: Arc<dyn PromptTokenizer>,
20 tokens_per_message: Option<u32>,
21 tokens_per_name: Option<i32>,
22 built_prompt_messages: Mutex<Option<Vec<HashMap<String, String>>>>,
23 total_prompt_tokens: Mutex<Option<u64>>,
24}
25
26impl ApiPrompt {
27 pub fn new(
28 tokenizer: Arc<dyn PromptTokenizer>,
29 tokens_per_message: Option<u32>,
30 tokens_per_name: Option<i32>,
31 ) -> Self {
32 Self {
33 tokenizer,
34 tokens_per_message,
35 tokens_per_name,
36 total_prompt_tokens: None.into(),
37 built_prompt_messages: None.into(),
38 }
39 }
40
41 pub(crate) fn clear_built_prompt(&self) {
45 *self.built_prompt_messages() = None;
46 *self.total_prompt_tokens() = None;
47 }
48
49 pub fn get_built_prompt(&self) -> Result<Vec<HashMap<String, String>>, crate::Error> {
65 match &*self.built_prompt_messages() {
66 Some(prompt) => Ok(prompt.clone()),
67 None => crate::bail!(
68 "ApiPrompt Error - built_prompt_messages not available - prompt not built"
69 ),
70 }
71 }
72
73 pub fn get_total_prompt_tokens(&self) -> Result<u64, crate::Error> {
87 match &*self.total_prompt_tokens() {
88 Some(prompt) => Ok(*prompt),
89 None => crate::bail!(
90 "ApiPrompt Error - total_prompt_tokens not available - prompt not built"
91 ),
92 }
93 }
94
95 pub(crate) fn build_prompt(&self, built_prompt_messages: &[HashMap<String, String>]) {
99 *self.total_prompt_tokens() = Some(total_prompt_tokens_openai_format(
100 built_prompt_messages,
101 self.tokens_per_message,
102 self.tokens_per_name,
103 &self.tokenizer,
104 ));
105
106 *self.built_prompt_messages() = Some(built_prompt_messages.to_vec());
107 }
108
109 fn built_prompt_messages(&self) -> MutexGuard<'_, Option<Vec<HashMap<String, String>>>> {
113 self.built_prompt_messages.lock().unwrap_or_else(|e| {
114 panic!(
115 "ApiPrompt Error - built_prompt_messages not available: {:?}",
116 e
117 )
118 })
119 }
120
121 fn total_prompt_tokens(&self) -> MutexGuard<'_, Option<u64>> {
122 self.total_prompt_tokens.lock().unwrap_or_else(|e| {
123 panic!(
124 "ApiPrompt Error - total_prompt_tokens not available: {:?}",
125 e
126 )
127 })
128 }
129}
130
131impl Clone for ApiPrompt {
132 fn clone(&self) -> Self {
133 Self {
134 tokenizer: self.tokenizer.clone(),
135 tokens_per_message: self.tokens_per_message,
136 tokens_per_name: self.tokens_per_name,
137 total_prompt_tokens: (*self.total_prompt_tokens()).into(),
138 built_prompt_messages: self.built_prompt_messages().clone().into(),
139 }
140 }
141}
142
143impl std::fmt::Display for ApiPrompt {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 writeln!(f)?;
146 writeln!(f, "ApiPrompt")?;
147
148 match *self.total_prompt_tokens() {
149 Some(ref prompt) => {
150 writeln!(f, "total_prompt_tokens:\n\n{}", prompt)?;
151 writeln!(f)?;
152 }
153 None => writeln!(f, "total_prompt_tokens: None")?,
154 };
155
156 Ok(())
157 }
158}