crate::ix!();
/// Blocking client for one-shot language-model queries (used here to
/// generate skill-tree JSON — see the system prompt in
/// `run_chat_completion`), pre-configured with a model, sampling
/// temperature, and completion-token cap.
#[derive(Debug)]
pub struct GrowerLanguageModelClient {
    /// Shared handle to the underlying OpenAI client; its failures are
    /// surfaced as `GrowerLanguageModelClientError`.
    openai_client: Arc<OpenAIClientHandle<GrowerLanguageModelClientError>>,
    /// Model requested on every completion (`O1Pro` as set by `new`).
    model: LanguageModelType,
    /// Sampling temperature forwarded on every request (0.7 from `new`).
    temperature: f32,
    /// Upper bound on completion tokens per request (8192 from `new`).
    max_tokens: u16,
}
impl GrowerLanguageModelClient {
pub fn new() -> Self {
Self {
openai_client: OpenAIClientHandle::new(),
model: LanguageModelType::O1Pro,
temperature: 0.7,
max_tokens: 8192,
}
}
#[instrument(level = "trace", skip(self, query_string))]
pub fn run_oneshot_query(&self, query_string: &str) -> Result<String, GrowerLanguageModelClientError> {
trace!("Preparing to run a one-shot query with async-openai. Query length: {} chars", query_string.len());
let rt = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(r) => r,
Err(e) => {
error!("Failed to build a tokio runtime for one-shot query: {:?}", e);
return Err(GrowerLanguageModelClientError::FailedToBuildTokioRuntimeForOneShotQuery);
}
};
rt.block_on(async {
self.run_chat_completion(query_string).await
})
}
#[instrument(level = "trace", skip(self, user_text))]
async fn run_chat_completion(&self, user_text: &str) -> Result<String, GrowerLanguageModelClientError> {
trace!("Constructing chat request for model={}", self.model);
let system_prompt =
"You are a skill-tree generator. Please produce valid JSON only, no extraneous text."
.to_string();
let request = CreateChatCompletionRequestArgs::default()
.model(self.model.to_string())
.max_tokens(self.max_tokens)
.temperature(self.temperature)
.messages(vec![
ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(system_prompt),
name: None,
}
),
ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(
user_text.to_string()
),
name: None,
}
),
])
.build()
.map_err(|_e| {
error!("Could not build chat completion request.");
GrowerLanguageModelClientError::CouldNotBuildChatCompletionRequest
})?;
trace!("Sending request to the OpenAI /v1/chat/completions endpoint...");
let response = self.openai_client.chat().create(request).await?;
let content = match response.choices.first() {
Some(choice) => {
debug!("Successfully got a completion choice from OpenAI");
choice.message.content.clone().expect("we expect this to be set")
}
None => {
error!("No choices returned by OpenAI completion");
return Err(GrowerLanguageModelClientError::NoChoicesReturnedByOpenAICompletion);
}
};
trace!("Returning raw content from OpenAI chat response. length={}", content.len());
Ok(content)
}
#[instrument(level = "trace", skip(self, query_string))]
pub fn run_oneshot_query_with_repair<TargetType, TargetErrorType>(
&self,
query_string: &str,
) -> Result<TargetType, TargetErrorType>
where
TargetType: FuzzyFromJsonValue,
TargetErrorType: From<JsonRepairError>
+ From<GrowerLanguageModelClientError>
+ From<FuzzyFromJsonValueError>,
{
trace!("Running one-shot query with JSON repair + fuzzy parse");
let language_model_response = self.run_oneshot_query(query_string)?;
let language_model_response_json: serde_json::Value =
repair_json_string(&language_model_response)?;
let target = TargetType::fuzzy_from_json_value(&language_model_response_json)?;
Ok(target)
}
}