pub struct TextClassifier<T: Class> { /* private fields */ }
Expand description
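A text classifier. It is built from a `Classifier<T>` model and maps a text `Embedding` to a `ClassifierOutput<T>` over the classes of `T`.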
§Example
use candle_core::Device;
use kalosm_language_model::{Embedder, EmbedderExt};
use kalosm_learning::{
Class, Classifier, ClassifierConfig, TextClassifier, TextClassifierDatasetBuilder,
};
use rbert::Bert;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The two labels the classifier learns to tell apart.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Class)]
    enum MyClass {
        Person,
        Thing,
    }

    // `bert` is borrowed mutably while the dataset is built, and reused afterwards
    // to embed the test prompts.
    let mut bert = Bert::builder().build().await?;
    let dev = Device::cuda_if_available(0)?;

    let person_questions = vec![
        "What is the author's name?",
        "What is the author's age?",
        "Who is the queen of England?",
        "Who is the president of the United States?",
        "Who is the president of France?",
        "Tell me about the CEO of Apple.",
        "Who is the CEO of Google?",
        "Who is the CEO of Microsoft?",
        "What person invented the light bulb?",
        "What person invented the telephone?",
        "What is the name of the person who invented the light bulb?",
        "Who wrote the book 'The Lord of the Rings'?",
        "Who wrote the book 'The Hobbit'?",
        "How old is the author of the book 'The Lord of the Rings'?",
        "How old is the author of the book 'The Hobbit'?",
        "Who is the best soccer player in the world?",
        "Who is the best basketball player in the world?",
        "Who is the best tennis player in the world?",
        "Who is the best soccer player in the world right now?",
        "Who is the leader of the United States?",
        "Who is the leader of France?",
        "What is the name of the leader of the United States?",
        "What is the name of the leader of France?",
    ];
    let thing_sentences = vec![
        "What is the capital of France?",
        "What is the capital of England?",
        "What is the name of the biggest city in the world?",
        "What tool do you use to cut a tree?",
        "What tool do you use to cut a piece of paper?",
        "What is a good book to read?",
        "What is a good movie to watch?",
        "What is a good song to listen to?",
        "What is the best tool to use to create a website?",
        "What is the best tool to use to create a mobile app?",
        "How long does it take to fly from Paris to New York?",
        "How do you make a cake?",
        "How do you make a pizza?",
        "How can you make a website?",
        "What is the best way to learn a new language?",
        "What is the best way to learn a new programming language?",
        "What is a framework?",
        "What is a library?",
        "What is a good way to learn a new language?",
        "What is a good way to learn a new programming language?",
        "What is the city with the most people in the world?",
        "What is the most spoken language in the world?",
        "What is the most spoken language in the United States?",
    ];

    // Attach a label to every example sentence, then materialize the dataset on `dev`.
    let mut dataset = TextClassifierDatasetBuilder::<MyClass, _>::new(&mut bert);
    for question in &person_questions {
        dataset.add(question, MyClass::Person).await?;
    }
    for sentence in &thing_sentences {
        dataset.add(sentence, MyClass::Thing).await?;
    }
    let dataset = dataset.build(&dev)?;

    // Each attempt constructs a brand-new classifier, so a failed training run can be
    // retried from scratch. The number of attempts is capped so that an error which
    // recurs every time is returned to the caller instead of looping forever.
    const MAX_TRAINING_ATTEMPTS: usize = 5;
    let layers = vec![5, 8, 5];
    let mut attempt = 1;
    let classifier = loop {
        let classifier = TextClassifier::<MyClass>::new(Classifier::new(
            &dev,
            ClassifierConfig::new().layers_dims(layers.clone()),
        )?);
        match classifier.train(&dataset, 100, 0.05, 3, |_| {}) {
            Ok(_) => break classifier,
            Err(error) => {
                println!("Error: {:?}", error);
                if attempt >= MAX_TRAINING_ATTEMPTS {
                    return Err(error.into());
                }
            }
        }
        attempt += 1;
        println!("Retrying...");
    };

    // `Classifier::load` needs the layer configuration, so read it from the trained
    // classifier before saving, then reload the saved weights from disk.
    let config = classifier.config();
    classifier.save("classifier.safetensors")?;
    let classifier = Classifier::<MyClass>::load("classifier.safetensors", &dev, config)?;

    let tests = [
        "Who is the president of Russia?",
        "What is the capital of Russia?",
        "Who invented the TV?",
        "What is the best way to learn a how to ride a bike?",
    ];
    for test in &tests {
        let input = bert.embed(test).await?;
        let class = classifier.run(input.vector())?;
        println!();
        println!("{test}");
        // Print only the first few embedding components to keep the output short.
        println!("{:?} {:?}", &input.vector()[..5], class);
    }

    Ok(())
}
Implementations§
Source§impl<T: Class> TextClassifier<T>
impl<T: Class> TextClassifier<T>
Sourcepub fn new(model: Classifier<T>) -> Self
pub fn new(model: Classifier<T>) -> Self
Creates a new `TextClassifier` from the given `Classifier` model.
Sourcepub fn run(&self, input: Embedding) -> Result<ClassifierOutput<T>>
pub fn run(&self, input: Embedding) -> Result<ClassifierOutput<T>>
Runs the classifier on the given text embedding, returning a `ClassifierOutput<T>`.
Sourcepub fn train(
&self,
dataset: &ClassificationDataset,
epochs: usize,
learning_rate: f64,
batch_size: usize,
progress: impl FnMut(ClassifierProgress),
) -> Result<f32>
pub fn train( &self, dataset: &ClassificationDataset, epochs: usize, learning_rate: f64, batch_size: usize, progress: impl FnMut(ClassifierProgress), ) -> Result<f32>
Trains the classifier on the given dataset.
Sourcepub fn config(&self) -> ClassifierConfig
pub fn config(&self) -> ClassifierConfig
Gets the configuration of the classifier.
Auto Trait Implementations§
impl<T> !Freeze for TextClassifier<T>
impl<T> !RefUnwindSafe for TextClassifier<T>
impl<T> Send for TextClassifier<T>
where
    T: Send,
impl<T> Sync for TextClassifier<T>
where
    T: Sync,
impl<T> Unpin for TextClassifier<T>
where
    T: Unpin,
impl<T> !UnwindSafe for TextClassifier<T>
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more