pub mod chat;
pub mod json;
pub mod xml;
use crate::error::Result;
use crate::intern::Sym;
use crate::predict::FieldRange;
use crate::signature::Signature;
use crate::str_view::StrView;
use crate::types::Inputs;
use smallvec::SmallVec;
/// Interface for converting between a [`Signature`] and a textual exchange.
///
/// Implementors must be `Send + Sync`. The `chat`, `json` and `xml`
/// submodules presumably provide the concrete implementations
/// (NOTE(review): their `impl Adapter` blocks are not visible here — confirm).
pub trait Adapter: Send + Sync {
/// Builds formatted text for `signature`, `inputs` and `demos`.
///
/// The returned [`StrView`] carries the same lifetime `'a` as the mutable
/// borrow of `buffer`, so the result cannot outlive that borrow.
/// NOTE(review): presumably the text is written into `buffer` and the view
/// points at it — confirm against the implementations.
fn format<'a>(
&self,
buffer: &'a mut Vec<u8>,
signature: &Signature<'_>,
inputs: &Inputs<'_>,
demos: &[DemoData<'_>],
) -> StrView<'a>;
/// Parses `response` against `signature`.
///
/// On success returns `(Sym, FieldRange)` pairs in a `SmallVec` with inline
/// room for four entries; failures are reported through the crate's
/// [`Result`] alias.
/// NOTE(review): each `FieldRange` presumably locates a field's value inside
/// `response` — confirm against `crate::predict::FieldRange`.
fn parse<'a>(
&self,
response: StrView<'a>,
signature: &Signature<'_>,
) -> Result<SmallVec<[(Sym, FieldRange); 4]>>;
/// Returns a `'static` name for this adapter (presumably a short
/// identifier for logging or selection — confirm with callers).
fn name(&self) -> &'static str;
}
/// Borrowed view over one demo's input and output fields.
///
/// Holds only two slice references, which is why it can derive `Copy`.
/// NOTE(review): a "demo" is presumably a worked example shown alongside the
/// real inputs — confirm with the adapter implementations.
#[derive(Clone, Copy)]
pub struct DemoData<'a> {
/// Input fields as `(symbol, value)` pairs.
pub inputs: &'a [(Sym, StrView<'a>)],
/// Output fields as `(symbol, value)` pairs.
pub outputs: &'a [(Sym, StrView<'a>)],
}
impl<'a> DemoData<'a> {
pub const fn new(inputs: &'a [(Sym, StrView<'a>)], outputs: &'a [(Sym, StrView<'a>)]) -> Self {
Self { inputs, outputs }
}
pub fn get_input(&self, sym: Sym) -> Option<StrView<'a>> {
self.inputs.iter().find(|(s, _)| *s == sym).map(|(_, v)| *v)
}
pub fn get_output(&self, sym: Sym) -> Option<StrView<'a>> {
self.outputs
.iter()
.find(|(s, _)| *s == sym)
.map(|(_, v)| *v)
}
}
/// Appends a display form of `name` to `buffer` as UTF-8 bytes.
///
/// The first character is upper-cased (which may expand to several
/// characters, e.g. `ß` becomes `SS`) and every later `_` is replaced by a
/// space. A leading `_` is left as is, and an empty `name` appends nothing.
/// Existing contents of `buffer` are preserved.
///
/// Characters are encoded as UTF-8 rather than truncated to a single byte, so
/// non-ASCII field names produce valid text.
#[inline]
pub fn format_field_name(buffer: &mut Vec<u8>, name: &str) {
    // The output is at least roughly as long as the input; reserve once.
    buffer.reserve(name.len());

    let mut chars = name.chars();
    if let Some(first) = chars.next() {
        let mut utf8 = [0u8; 4];
        for upper in first.to_uppercase() {
            buffer.extend_from_slice(upper.encode_utf8(&mut utf8).as_bytes());
        }
    }

    // `_` is ASCII and can never appear inside a multi-byte UTF-8 sequence,
    // so the remainder can be copied byte-wise while swapping underscores.
    buffer.extend(
        chars
            .as_str()
            .bytes()
            .map(|byte| if byte == b'_' { b' ' } else { byte }),
    );
}
/// Returns the display form of `name` as an owned `String`.
///
/// The first character is upper-cased and every subsequent `_` becomes a
/// space; the first character itself is never turned into a space.
fn formatted_field_name(name: &str) -> String {
    let mut display = String::with_capacity(name.len());
    let mut rest = name.chars();

    match rest.next() {
        // Upper-casing can yield more than one character.
        Some(head) => display.extend(head.to_uppercase()),
        None => return display,
    }

    display.extend(rest.map(|ch| if ch == '_' { ' ' } else { ch }));
    display
}
/// Locates the value that follows a label for `field_name` inside `text`.
///
/// Labels are tried in a fixed priority order: first with the display form of
/// the name (see `formatted_field_name`), then with `field_name` verbatim. For
/// each name the shapes `"Name: "`, `"Name:"`, `"[Name] "` and `"[Name]"` are
/// tried in that order. Only the first occurrence of each label is examined;
/// if the value found there is empty, the next label is tried.
///
/// Returns the byte range of the value within `text`, or `None` when no label
/// yields a non-empty value.
pub fn find_field_value(text: &str, field_name: &str) -> Option<std::ops::Range<usize>> {
    let display_name = formatted_field_name(field_name);

    [display_name.as_str(), field_name]
        .iter()
        .flat_map(|name| {
            [
                format!("{}: ", name),
                format!("{}:", name),
                format!("[{}] ", name),
                format!("[{}]", name),
            ]
        })
        .find_map(|label| {
            let value_start = text.find(label.as_str())? + label.len();
            let value_end = find_value_end(text, value_start);
            if value_start < value_end {
                Some(value_start..value_end)
            } else {
                None
            }
        })
}
/// Returns the byte offset in `text` at which the value beginning at `start`
/// ends.
///
/// The value runs until the earliest of a blank line, a line starting with
/// `[`, or a line starting with `Answer:`, `Question:` or `Reasoning:`, and
/// otherwise to the end of `text`. Trailing whitespace is excluded.
///
/// `start` must be a valid char boundary within `text`; slicing panics
/// otherwise.
fn find_value_end(text: &str, start: usize) -> usize {
    const TERMINATORS: [&str; 5] = ["\n\n", "\n[", "\nAnswer:", "\nQuestion:", "\nReasoning:"];

    let tail = &text[start..];
    let cut = TERMINATORS
        .iter()
        .filter_map(|marker| tail.find(*marker))
        .min()
        .unwrap_or(tail.len());

    start + tail[..cut].trim_end().len()
}
pub use chat::{ChatAdapter, ChatConfig};
pub use json::{JSONAdapter, JSONConfig};
pub use xml::{XMLAdapter, XMLConfig};
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `format_field_name` on a fresh buffer and returns the result as text.
    fn render(name: &str) -> String {
        let mut buffer = Vec::new();
        format_field_name(&mut buffer, name);
        String::from_utf8(buffer).unwrap()
    }

    #[test]
    fn test_format_field_name() {
        assert_eq!(render("question"), "Question");
        assert_eq!(render("chain_of_thought"), "Chain of thought");
    }

    #[test]
    fn test_find_field_value() {
        let text = "Question: What is 2+2?\nAnswer: 4";
        let cases = [("Question", "What is 2+2?"), ("Answer", "4")];
        for (field, expected) in cases.iter() {
            let range = find_field_value(text, field);
            assert!(range.is_some());
            assert_eq!(&text[range.unwrap()], *expected);
        }
    }

    #[test]
    fn test_demo_data() {
        use crate::intern::sym;

        let question = sym("question");
        let answer = sym("answer");
        let inputs = [(question, StrView::new("What is 2+2?"))];
        let outputs = [(answer, StrView::new("4"))];
        let demo = DemoData::new(&inputs, &outputs);

        let found_input = demo.get_input(question).map(|v| v.as_str());
        assert_eq!(found_input, Some("What is 2+2?"));

        let found_output = demo.get_output(answer).map(|v| v.as_str());
        assert_eq!(found_output, Some("4"));
    }
}