// capability_example/grower_language_model_client.rs
crate::ix!();
/// Thin client wrapper around async-openai for one-shot chat-completion
/// queries, preconfigured for the skill-tree-generation use case.
#[derive(Debug)]
pub struct GrowerLanguageModelClient {

    /// Shared handle to the underlying OpenAI client; its failures surface
    /// as `GrowerLanguageModelClientError`.
    openai_client: Arc<OpenAIClientHandle<GrowerLanguageModelClientError>>,

    /// Which language model to request (set to `O1Pro` by `new()`).
    model: LanguageModelType,

    /// Sampling temperature forwarded verbatim to each chat-completion request.
    temperature: f32,

    /// Upper bound on completion tokens (`max_tokens` in the request).
    max_tokens: u16,
}
21
22impl GrowerLanguageModelClient {
23
24 pub fn new() -> Self {
25 Self {
26 openai_client: OpenAIClientHandle::new(),
27 model: LanguageModelType::O1Pro,
28 temperature: 0.7,
29 max_tokens: 8192,
30 }
31 }
32
33 #[instrument(level = "trace", skip(self, query_string))]
37 pub fn run_oneshot_query(&self, query_string: &str) -> Result<String, GrowerLanguageModelClientError> {
38 trace!("Preparing to run a one-shot query with async-openai. Query length: {} chars", query_string.len());
39
40 let rt = match tokio::runtime::Builder::new_current_thread()
42 .enable_all()
43 .build()
44 {
45 Ok(r) => r,
46 Err(e) => {
47 error!("Failed to build a tokio runtime for one-shot query: {:?}", e);
48 return Err(GrowerLanguageModelClientError::FailedToBuildTokioRuntimeForOneShotQuery);
49 }
50 };
51
52 rt.block_on(async {
54 self.run_chat_completion(query_string).await
55 })
56 }
57
58 #[instrument(level = "trace", skip(self, user_text))]
60 async fn run_chat_completion(&self, user_text: &str) -> Result<String, GrowerLanguageModelClientError> {
61
62 trace!("Constructing chat request for model={}", self.model);
63
64 let system_prompt =
65 "You are a skill-tree generator. Please produce valid JSON only, no extraneous text."
66 .to_string();
67
68 let request = CreateChatCompletionRequestArgs::default()
70 .model(self.model.to_string())
71 .max_tokens(self.max_tokens)
72 .temperature(self.temperature)
73 .messages(vec![
74 ChatCompletionRequestMessage::System(
75 ChatCompletionRequestSystemMessage {
76 content: ChatCompletionRequestSystemMessageContent::Text(system_prompt),
77 name: None,
78 }
79 ),
80 ChatCompletionRequestMessage::User(
81 ChatCompletionRequestUserMessage {
82 content: ChatCompletionRequestUserMessageContent::Text(
83 user_text.to_string()
84 ),
85 name: None,
86 }
87 ),
88 ])
89 .build()
90 .map_err(|_e| {
91 error!("Could not build chat completion request.");
92 GrowerLanguageModelClientError::CouldNotBuildChatCompletionRequest
93 })?;
94
95 trace!("Sending request to the OpenAI /v1/chat/completions endpoint...");
96
97 let response = self.openai_client.chat().create(request).await?;
99
100 let content = match response.choices.first() {
102 Some(choice) => {
103 debug!("Successfully got a completion choice from OpenAI");
104 choice.message.content.clone().expect("we expect this to be set")
105 }
106 None => {
107 error!("No choices returned by OpenAI completion");
108 return Err(GrowerLanguageModelClientError::NoChoicesReturnedByOpenAICompletion);
109 }
110 };
111
112 trace!("Returning raw content from OpenAI chat response. length={}", content.len());
113 Ok(content)
114 }
115
116 #[instrument(level = "trace", skip(self, query_string))]
121 pub fn run_oneshot_query_with_repair<TargetType, TargetErrorType>(
122 &self,
123 query_string: &str,
124 ) -> Result<TargetType, TargetErrorType>
125 where
126 TargetType: FuzzyFromJsonValue,
127 TargetErrorType: From<JsonRepairError>
128 + From<GrowerLanguageModelClientError>
129 + From<FuzzyFromJsonValueError>,
130 {
131 trace!("Running one-shot query with JSON repair + fuzzy parse");
132 let language_model_response = self.run_oneshot_query(query_string)?;
134
135 let language_model_response_json: serde_json::Value =
137 repair_json_string(&language_model_response)?;
138
139 let target = TargetType::fuzzy_from_json_value(&language_model_response_json)?;
141
142 Ok(target)
143 }
144}