use crate::errors::InferenceError;
use crate::formatter::{SlmFormatter, SlmToolStyle};
use crate::{
SlmAnswer, SlmBoxedBrakeFn, SlmBrake, SlmContext, SlmInference, SlmRole, SlmSimpleInference,
};
/// Token budget given to a single generated answer when none is configured explicitly.
const DEFAULT_MAX_ANSWER_TOKENS: usize = 1 << 10;
/// Chat-style interface for feeding messages to a small language model and getting answers.
///
/// Implementors supply [`prompt`](Self::prompt), [`generate`](Self::generate) and
/// [`clear`](Self::clear); every other method is a convenience wrapper that fixes the role
/// (and, for [`ask`](Self::ask) / [`think`](Self::think), the `think` flag).
pub trait SlmOracle {
    /// Adds `text` to the conversation as a message sent by `role`.
    fn prompt(&mut self, role: &SlmRole, text: &str) -> Result<(), InferenceError>;

    /// Sends `text` as a message from `role` and returns the model's answer.
    ///
    /// `think` requests reasoning; `brake` is an optional caller-supplied brake handed to the
    /// generation.
    fn generate(
        &mut self,
        role: &SlmRole,
        text: &str,
        think: bool,
        brake: Option<SlmBoxedBrakeFn>,
    ) -> Result<SlmAnswer, InferenceError>;

    /// Resets the conversation.
    fn clear(&mut self) -> Result<(), InferenceError>;

    /// Shorthand for [`prompt`](Self::prompt) with [`SlmRole::System`].
    fn system(&mut self, text: &str) -> Result<(), InferenceError> {
        Self::prompt(self, &SlmRole::System, text)
    }

    /// Shorthand for [`prompt`](Self::prompt) with [`SlmRole::User`].
    fn user(&mut self, text: &str) -> Result<(), InferenceError> {
        Self::prompt(self, &SlmRole::User, text)
    }

    /// Shorthand for [`prompt`](Self::prompt) with [`SlmRole::Assistant`].
    fn assistant(&mut self, text: &str) -> Result<(), InferenceError> {
        Self::prompt(self, &SlmRole::Assistant, text)
    }

    /// Adds the response of the tool `tool_name` to the conversation.
    fn tool(&mut self, tool_name: &str, text: &str) -> Result<(), InferenceError> {
        Self::prompt(self, &SlmRole::tool(tool_name), text)
    }

    /// Asks `text` as the user and generates an answer without reasoning.
    fn ask(
        &mut self,
        text: &str,
        brake: Option<SlmBoxedBrakeFn>,
    ) -> Result<SlmAnswer, InferenceError> {
        Self::generate(self, &SlmRole::User, text, false, brake)
    }

    /// Asks `text` as the user and generates an answer with reasoning enabled.
    fn think(
        &mut self,
        text: &str,
        brake: Option<SlmBoxedBrakeFn>,
    ) -> Result<SlmAnswer, InferenceError> {
        Self::generate(self, &SlmRole::User, text, true, brake)
    }
}
/// Default [`SlmOracle`] implementation: renders every message with an [`SlmFormatter`] and
/// prefills the resulting text into an [`SlmInference`].
pub struct SlmSimpleOracle<I, F>
where
    I: SlmInference,
    F: SlmFormatter,
{
    /// Inference engine receiving the formatted text.
    inference: I,
    /// Template used to render BOS, turn markers and tool responses.
    formatter: F,
    /// Token limit applied to each generated answer.
    max_answer_tokens: usize,
    /// `true` until the BOS marker has been emitted for the current context.
    is_fresh_context: bool,
    /// Role of the turn currently left open in the context, or `None` if no turn is open.
    active_turn: Option<SlmRole>,
}
/// Guard that calls `rollback()` on the wrapped inference when it is dropped.
///
/// Presumably paired with a preceding `save()` on the same inference. Bind the guard to a
/// named variable: `let _ = SavePoint(..)` drops it — and rolls back — on the spot.
struct SavePoint<'a>(&'a mut dyn SlmInference);

impl Drop for SavePoint<'_> {
    fn drop(&mut self) {
        if let Err(err) = self.0.rollback() {
            // Panicking while the thread is already unwinding aborts the whole process, so a
            // failed rollback is only escalated when no other panic is in flight.
            if !std::thread::panicking() {
                panic!("failed to roll back inference save point: {err:?}");
            }
        }
    }
}
impl<C: SlmContext, F: SlmFormatter> SlmSimpleOracle<SlmSimpleInference<C>, F> {
    /// Creates an oracle over `context`, rendering messages with `formatter`.
    ///
    /// The oracle starts with no open turn, will emit the BOS marker on its first message, and
    /// limits answers to `DEFAULT_MAX_ANSWER_TOKENS` tokens.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SlmSimpleInference::new`.
    pub fn new(context: C, formatter: F) -> Result<Self, InferenceError> {
        let oracle = Self {
            inference: SlmSimpleInference::new(context)?,
            formatter,
            max_answer_tokens: DEFAULT_MAX_ANSWER_TOKENS,
            is_fresh_context: true,
            active_turn: None,
        };
        Ok(oracle)
    }
}
impl<I: SlmInference, F: SlmFormatter> SlmSimpleOracle<I, F> {
    /// Appends the formatter's BOS marker (if it has one) to `s`, but only on the first call
    /// after the context became fresh; later calls leave `s` untouched.
    fn bos(&mut self, s: &mut String) {
        if !self.is_fresh_context {
            return;
        }
        self.is_fresh_context = false;
        if let Some(marker) = self.formatter.bos() {
            s.push_str(marker);
        }
    }
}
impl<I: SlmInference, F: SlmFormatter> SlmSimpleOracle<I, F> {
    /// Appends to `fragment` the text that adds a `role` message containing `text`.
    ///
    /// Emits the BOS marker on a fresh context. Whenever the message belongs to a different
    /// turn than the one currently open, it closes the open turn (if any) and opens the new one.
    /// It then appends the message body, wrapping tool responses with `format_tool_response`.
    /// `self.active_turn` is left pointing at the turn that is now open.
    ///
    /// With [`SlmToolStyle::Inline`] tool responses are embedded in the assistant turn; with
    /// [`SlmToolStyle::SeparateTurn`] every role, tools included, gets a turn of its own.
    ///
    /// Only `fragment` and the turn bookkeeping are touched; nothing is sent to the inference.
    /// This never fails today; the `Result` is part of the internal contract.
    fn prepare_prompt(
        &mut self,
        role: &SlmRole,
        text: &str,
        fragment: &mut String,
    ) -> Result<(), InferenceError> {
        self.bos(fragment);

        // The role whose turn this message lives in.
        let turn_role = match (self.formatter.tool_style(), role) {
            (SlmToolStyle::Inline, SlmRole::Tool(_)) => SlmRole::Assistant,
            (SlmToolStyle::Inline | SlmToolStyle::SeparateTurn, _) => role.clone(),
        };

        // Close the open turn exactly once and open the new one. This also covers the very
        // first message, when no turn is open yet.
        if self.active_turn.as_ref() != Some(&turn_role) {
            if let Some(active_role) = &self.active_turn {
                fragment.push_str(&self.formatter.turn_end(active_role));
            }
            fragment.push_str(&self.formatter.turn_start(&turn_role));
        }

        match role {
            SlmRole::Tool(tool_name) => {
                fragment.push_str(&self.formatter.format_tool_response(tool_name, text));
            }
            SlmRole::System | SlmRole::User | SlmRole::Assistant => fragment.push_str(text),
        }

        self.active_turn = Some(turn_role);
        Ok(())
    }
}
impl<I: SlmInference, F: SlmFormatter> SlmOracle for SlmSimpleOracle<I, F> {
    /// Formats `text` as a message from `role` and prefills it into the context.
    fn prompt(&mut self, role: &SlmRole, text: &str) -> Result<(), InferenceError> {
        let mut fragment = String::new();
        self.prepare_prompt(role, text, &mut fragment)?;
        self.inference.prefill(&fragment)
    }

    /// Asks the model a question and returns its answer without keeping the exchange.
    ///
    /// Steps:
    /// - `text` is formatted as a `role` message.
    /// - An assistant turn is opened, followed by the formatter's reasoning trigger when
    ///   `think` is set.
    /// - The model generates until `brake` or the `max_answer_tokens` limit stops it.
    /// - Everything prefilled or generated here is rolled back through a [`SavePoint`].
    /// - The turn bookkeeping is restored to match, so the oracle is left as it was before the
    ///   call. This assumes `rollback` restores the state captured by `save`.
    ///
    /// # Errors
    ///
    /// [`InferenceError::InvalidRole`] when `role` is `Assistant` or `System`; otherwise any
    /// error from saving, prefilling or generating.
    fn generate(
        &mut self,
        role: &SlmRole,
        text: &str,
        think: bool,
        brake: Option<SlmBoxedBrakeFn>,
    ) -> Result<SlmAnswer, InferenceError> {
        if matches!(role, SlmRole::Assistant | SlmRole::System) {
            return Err(InferenceError::InvalidRole);
        }

        // Building the fragment advances the turn bookkeeping, but the fragment itself is
        // rolled back below. Snapshot the bookkeeping so it keeps matching the context
        // (otherwise the next prompt would close a vanished turn or skip BOS).
        let turn_before = self.active_turn.clone();
        let fresh_before = self.is_fresh_context;
        let fragment = self.generation_fragment(role, text, think);
        self.active_turn = turn_before;
        self.is_fresh_context = fresh_before;
        let fragment = fragment?;

        self.inference.save()?;
        // The guard must be bound to a name: `let _ = SavePoint(..)` drops it on the spot. That
        // rolls back before anything is prefilled and leaves the exchange in the context.
        let save_point = SavePoint(&mut self.inference);
        save_point.0.prefill(&fragment)?;
        let mut answer = save_point
            .0
            .generate_until(&mut [brake, Some(SlmBrake::token_limit(self.max_answer_tokens))])?;
        // The answer is captured; roll the context back now. Every `?` above also rolls back
        // when it drops the guard.
        drop(save_point);

        if think {
            // The trigger was prefilled rather than generated; put it back in front of the text
            // so `split_thought` sees the complete reasoning block.
            answer =
                answer.map(|s| self.formatter.reasoning_trigger().unwrap_or("").to_string() + &s);
        }
        Ok(answer.split_thought(&self.formatter))
    }

    /// Forgets the whole conversation. It clears the inference context and resets the turn
    /// bookkeeping, so the next message re-emits the BOS marker (if the formatter has one).
    fn clear(&mut self) -> Result<(), InferenceError> {
        self.is_fresh_context = true;
        self.active_turn = None;
        self.inference.clear()
    }
}

impl<I: SlmInference, F: SlmFormatter> SlmSimpleOracle<I, F> {
    /// Builds the text `generate` prefills before generating.
    ///
    /// The text is the `role` message, then (if not already inside it) the closing of that turn
    /// and the opening of the assistant turn, then the reasoning trigger when `think` is set.
    /// The turn bookkeeping is advanced as if the text had been prefilled; `generate` restores
    /// it afterwards.
    fn generation_fragment(
        &mut self,
        role: &SlmRole,
        text: &str,
        think: bool,
    ) -> Result<String, InferenceError> {
        let mut fragment = String::new();
        self.prepare_prompt(role, text, &mut fragment)?;
        if self.active_turn != Some(SlmRole::Assistant) {
            fragment.push_str(&self.formatter.turn_end(role));
            fragment.push_str(&self.formatter.turn_start(&SlmRole::Assistant));
            self.active_turn = Some(SlmRole::Assistant);
        }
        if think {
            fragment.push_str(self.formatter.reasoning_trigger().unwrap_or(""));
        }
        Ok(fragment)
    }
}