use candle_core::Device;
use kalosm_language_model::{Embedder, EmbedderExt, Embedding};
use crate::{
Class, ClassificationDataset, ClassificationDatasetBuilder, Classifier, ClassifierConfig,
ClassifierOutput,
};
use super::ClassifierProgress;
/// Builder that turns labelled text samples into a [`ClassificationDataset`].
///
/// Each sample is embedded with the borrowed embedder and the resulting vector
/// is handed, together with its class, to the wrapped
/// [`ClassificationDatasetBuilder`].
pub struct TextClassifierDatasetBuilder<'a, T, E>
where
    T: Class,
    E: Embedder,
{
    /// Collects the embedded samples together with their classes.
    dataset: ClassificationDatasetBuilder<T>,
    /// Embedder used to convert incoming text into vectors.
    embedder: &'a E,
}
impl<'a, T, E> TextClassifierDatasetBuilder<'a, T, E>
where
    T: Class,
    E: Embedder,
{
    /// Creates a builder that embeds its samples with `embedder`.
    pub fn new(embedder: &'a E) -> Self {
        let dataset = ClassificationDatasetBuilder::new();
        Self { dataset, embedder }
    }

    /// Embeds `text` and records the resulting vector under `class`.
    ///
    /// # Errors
    ///
    /// Returns the embedder's error if embedding `text` fails; nothing is
    /// recorded in that case.
    pub async fn add(&mut self, text: impl ToString, class: T) -> Result<(), E::Error> {
        let embedded = self.embedder.embed(text).await?;
        // The dataset builder is handed an owned copy of the embedding vector.
        let features = embedded.vector().to_vec().into_boxed_slice();
        self.dataset.add(features, class);
        Ok(())
    }

    /// Embeds every text in `examples` with a single `embed_batch` call and
    /// records each resulting vector under the class it was paired with.
    ///
    /// # Errors
    ///
    /// Returns the embedder's error if the batch embedding fails; nothing is
    /// recorded in that case.
    pub async fn extend(
        &mut self,
        examples: impl IntoIterator<Item = (impl ToString, T)>,
    ) -> Result<(), E::Error> {
        // Separate the texts from their classes so the texts can be embedded as one batch.
        let (texts, classes): (Vec<_>, Vec<_>) = examples.into_iter().unzip();
        let embedded = self.embedder.embed_batch(texts).await?;
        // Embeddings are matched back to classes by position.
        // NOTE(review): assumes `embed_batch` yields one embedding per text, in input order — confirm.
        embedded
            .into_iter()
            .zip(classes)
            .for_each(|(embedding, class)| {
                let features = embedding.vector().to_vec().into_boxed_slice();
                self.dataset.add(features, class);
            });
        Ok(())
    }

    /// Consumes the builder and builds the dataset on `device`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying dataset builder.
    pub fn build(self, device: &Device) -> candle_core::Result<ClassificationDataset> {
        let Self { dataset, .. } = self;
        dataset.build(device)
    }
}
/// Classifier that operates on text embeddings.
///
/// This is a thin wrapper around [`Classifier`]; every method delegates to it.
pub struct TextClassifier<T>
where
    T: Class,
{
    /// Underlying model that performs inference, training and persistence.
    model: Classifier<T>,
}
impl<T> TextClassifier<T>
where
    T: Class,
{
    /// Wraps an existing [`Classifier`].
    pub fn new(model: Classifier<T>) -> Self {
        Self { model }
    }

    /// Runs the wrapped classifier on the vector of `input`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped classifier.
    pub fn run(&self, input: Embedding) -> candle_core::Result<ClassifierOutput<T>> {
        let features = input.vector();
        self.model.run(features)
    }

    /// Trains the wrapped classifier on `dataset`.
    ///
    /// All arguments are forwarded unchanged to the wrapped classifier's
    /// `train` method, and its result is returned as-is.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped classifier.
    pub fn train(
        &self,
        dataset: &ClassificationDataset,
        epochs: usize,
        learning_rate: f64,
        batch_size: usize,
        progress: impl FnMut(ClassifierProgress),
    ) -> candle_core::Result<f32> {
        let Self { model } = self;
        model.train(dataset, epochs, learning_rate, batch_size, progress)
    }

    /// Returns the configuration of the wrapped classifier.
    pub fn config(&self) -> ClassifierConfig {
        let Self { model } = self;
        model.config()
    }

    /// Saves the wrapped classifier to `path`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped classifier.
    pub fn save<P>(&self, path: P) -> candle_core::Result<()>
    where
        P: AsRef<std::path::Path>,
    {
        let Self { model } = self;
        model.save(path)
    }

    /// Loads a classifier from `path` onto `device` using `config`, and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Classifier::load`].
    pub fn load<P>(path: P, device: &Device, config: ClassifierConfig) -> candle_core::Result<Self>
    where
        P: AsRef<std::path::Path>,
    {
        Ok(Self {
            model: Classifier::load(path, device, config)?,
        })
    }
}
/// End-to-end smoke test: embeds two small sets of questions, trains a
/// classifier on them, round-trips the weights through a safetensors file and
/// prints the predictions for a few unseen questions.
#[cfg(test)]
#[tokio::test]
async fn simplified() -> Result<(), Box<dyn std::error::Error>> {
    use crate::{Class, Classifier, ClassifierConfig};
    use rbert::{Bert, BertSource};

    /// Upper bound on training attempts, so that a persistent training error
    /// fails the test instead of retrying forever.
    const MAX_TRAINING_ATTEMPTS: usize = 10;

    #[derive(Debug, Copy, Clone, PartialEq, Eq, Class)]
    enum MyClass {
        Person,
        Thing,
    }

    let bert = Bert::builder()
        .with_source(BertSource::snowflake_arctic_embed_extra_small())
        .build()
        .await?;
    let dev = kalosm_common::accelerated_device_if_available()?;

    let person_questions = [
        "What is the author's name?",
        "What is the author's age?",
        "Who is the queen of England?",
        "Who is the president of the United States?",
        "Who is the president of France?",
        "Tell me about the CEO of Apple.",
        "Who is the CEO of Google?",
        "Who is the CEO of Microsoft?",
        "What person invented the light bulb?",
        "What person invented the telephone?",
        "What is the name of the person who invented the light bulb?",
        "Who wrote the book 'The Lord of the Rings'?",
        "Who wrote the book 'The Hobbit'?",
        "How old is the author of the book 'The Lord of the Rings'?",
        "How old is the author of the book 'The Hobbit'?",
        "Who is the best soccer player in the world?",
        "Who is the best basketball player in the world?",
        "Who is the best tennis player in the world?",
        "Who is the best soccer player in the world right now?",
        "Who is the leader of the United States?",
        "Who is the leader of France?",
        "What is the name of the leader of the United States?",
        "What is the name of the leader of France?",
    ];
    let thing_sentences = [
        "What is the capital of France?",
        "What is the capital of England?",
        "What is the name of the biggest city in the world?",
        "What tool do you use to cut a tree?",
        "What tool do you use to cut a piece of paper?",
        "What is a good book to read?",
        "What is a good movie to watch?",
        "What is a good song to listen to?",
        "What is the best tool to use to create a website?",
        "What is the best tool to use to create a mobile app?",
        "How long does it take to fly from Paris to New York?",
        "How do you make a cake?",
        "How do you make a pizza?",
        "How can you make a website?",
        "What is the best way to learn a new language?",
        "What is the best way to learn a new programming language?",
        "What is a framework?",
        "What is a library?",
        "What is a good way to learn a new language?",
        "What is a good way to learn a new programming language?",
        "What is the city with the most people in the world?",
        "What is the most spoken language in the world?",
        "What is the most spoken language in the United States?",
    ];

    // Embed every example one at a time and label it with its class.
    let mut dataset = TextClassifierDatasetBuilder::<MyClass, _>::new(&bert);
    for question in &person_questions {
        dataset.add(question, MyClass::Person).await?;
    }
    for sentence in &thing_sentences {
        dataset.add(sentence, MyClass::Thing).await?;
    }
    let dataset = dataset.build(&dev)?;

    // Train a freshly initialised model; on failure start over with a new
    // model, but give up (and fail the test) after a bounded number of attempts.
    let layers = vec![5, 8, 5];
    let mut attempt = 0;
    let classifier = loop {
        attempt += 1;
        let candidate = TextClassifier::<MyClass>::new(Classifier::new(
            &dev,
            ClassifierConfig::new().layers_dims(layers.clone()),
        )?);
        println!("Training...");
        match candidate.train(&dataset, 100, 0.05, 100, |_| {}) {
            Ok(_) => break candidate,
            Err(error) if attempt < MAX_TRAINING_ATTEMPTS => {
                println!("Error: {:?}", error);
                println!("Retrying...");
            }
            Err(error) => return Err(error.into()),
        }
    };

    // Round-trip the weights through a file in the OS temp directory so the
    // test does not leave artifacts in the working directory. The process id
    // keeps concurrently running test processes from sharing a file.
    let weights_path = std::env::temp_dir().join(format!(
        "text_classifier_test_{}.safetensors",
        std::process::id()
    ));
    let config = classifier.config();
    classifier.save(&weights_path)?;
    let loaded = Classifier::<MyClass>::load(&weights_path, &dev, config);
    // Best-effort cleanup: runs whether or not loading succeeded, and a
    // failure to delete the file must not fail the test.
    let _ = std::fs::remove_file(&weights_path);
    let classifier = loaded?;

    let tests = [
        "Who is the president of Russia?",
        "What is the capital of Russia?",
        "Who invented the TV?",
        "What is the best way to learn a how to ride a bike?",
    ];
    for test in &tests {
        let input = bert.embed(test).await?;
        let class = classifier.run(input.vector())?;
        println!();
        println!("{test}");
        // Only the first few components of the embedding are printed.
        println!("{:?} {:?}", &input.vector()[..5], class);
    }
    Ok(())
}