Struct TextClassifier

Source
pub struct TextClassifier<T: Class> { /* private fields */ }
Expand description

A text classifier: wraps a `Classifier<T>` and classifies text embeddings into the classes of `T`.

§Example

use candle_core::Device;
use kalosm_language_model::{Embedder, EmbedderExt};
use kalosm_learning::{
    Class, Classifier, ClassifierConfig, TextClassifier, TextClassifierDatasetBuilder,
};
use rbert::Bert;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The labels the classifier learns to predict.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Class)]
    enum MyClass {
        Person,
        Thing,
    }

    // Upper bound on how many freshly created models are trained before the
    // example gives up and reports the training error.
    const MAX_TRAINING_ATTEMPTS: usize = 10;

    // Bert turns every sentence into an embedding vector.
    let mut bert = Bert::builder().build().await?;

    let dev = Device::cuda_if_available(0)?;

    // Training questions whose answer is a person.
    let person_questions = [
        "What is the author's name?",
        "What is the author's age?",
        "Who is the queen of England?",
        "Who is the president of the United States?",
        "Who is the president of France?",
        "Tell me about the CEO of Apple.",
        "Who is the CEO of Google?",
        "Who is the CEO of Microsoft?",
        "What person invented the light bulb?",
        "What person invented the telephone?",
        "What is the name of the person who invented the light bulb?",
        "Who wrote the book 'The Lord of the Rings'?",
        "Who wrote the book 'The Hobbit'?",
        "How old is the author of the book 'The Lord of the Rings'?",
        "How old is the author of the book 'The Hobbit'?",
        "Who is the best soccer player in the world?",
        "Who is the best basketball player in the world?",
        "Who is the best tennis player in the world?",
        "Who is the best soccer player in the world right now?",
        "Who is the leader of the United States?",
        "Who is the leader of France?",
        "What is the name of the leader of the United States?",
        "What is the name of the leader of France?",
    ];

    // Training questions whose answer is a thing rather than a person.
    let thing_questions = [
        "What is the capital of France?",
        "What is the capital of England?",
        "What is the name of the biggest city in the world?",
        "What tool do you use to cut a tree?",
        "What tool do you use to cut a piece of paper?",
        "What is a good book to read?",
        "What is a good movie to watch?",
        "What is a good song to listen to?",
        "What is the best tool to use to create a website?",
        "What is the best tool to use to create a mobile app?",
        "How long does it take to fly from Paris to New York?",
        "How do you make a cake?",
        "How do you make a pizza?",
        "How can you make a website?",
        "What is the best way to learn a new language?",
        "What is the best way to learn a new programming language?",
        "What is a framework?",
        "What is a library?",
        "What is a good way to learn a new language?",
        "What is a good way to learn a new programming language?",
        "What is the city with the most people in the world?",
        "What is the most spoken language in the world?",
        "What is the most spoken language in the United States?",
    ];

    // Embed every training question with Bert and label it with its class.
    let mut dataset = TextClassifierDatasetBuilder::<MyClass, _>::new(&mut bert);

    for question in &person_questions {
        dataset.add(question, MyClass::Person).await?;
    }

    for question in &thing_questions {
        dataset.add(question, MyClass::Thing).await?;
    }

    let dataset = dataset.build(&dev)?;

    let layers = vec![5, 8, 5];

    // If training fails, retry with a freshly created model. The number of
    // attempts is bounded so that a persistent error (e.g. a device failure)
    // is returned to the caller instead of retrying forever.
    let mut attempt = 1;
    let classifier = loop {
        let classifier = TextClassifier::<MyClass>::new(Classifier::new(
            &dev,
            ClassifierConfig::new().layers_dims(layers.clone()),
        )?);
        match classifier.train(&dataset, 100, 0.05, 3, |_| {}) {
            Ok(_) => break classifier,
            Err(error) => {
                println!("Error: {:?}", error);
                if attempt >= MAX_TRAINING_ATTEMPTS {
                    return Err(error.into());
                }
                attempt += 1;
                println!("Retrying...");
            }
        }
    };

    // Round-trip the trained model through disk. The configuration is read
    // before saving so the weights can be reloaded with the same layout.
    let config = classifier.config();
    classifier.save("classifier.safetensors")?;
    let classifier = Classifier::<MyClass>::load("classifier.safetensors", &dev, config)?;

    // Questions the classifier did not see during training.
    let tests = [
        "Who is the president of Russia?",
        "What is the capital of Russia?",
        "Who invented the TV?",
        "What is the best way to learn how to ride a bike?",
    ];

    for test in &tests {
        let input = bert.embed(test).await?;
        let class = classifier.run(input.vector())?;
        println!();
        println!("{test}");
        // Show the first few embedding components next to the classifier output.
        println!("{:?} {:?}", &input.vector()[..5], class);
    }

    Ok(())
}

Implementations§

Source§

impl<T: Class> TextClassifier<T>

Source

pub fn new(model: Classifier<T>) -> Self

Creates a new TextClassifier that wraps the given `Classifier` model.

Source

pub fn run(&self, input: Embedding) -> Result<ClassifierOutput<T>>

Runs the classifier on the given embedding and returns its `ClassifierOutput`.

Source

pub fn train( &self, dataset: &ClassificationDataset, epochs: usize, learning_rate: f64, batch_size: usize, progress: impl FnMut(ClassifierProgress), ) -> Result<f32>

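Trains the classifier on the given dataset for `epochs` epochs, using the given `learning_rate` and `batch_size`. The `progress` callback receives `ClassifierProgress` updates while training runs.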

Source

pub fn config(&self) -> ClassifierConfig

Gets the configuration of the classifier. Pass it to `load` to restore a saved classifier, as the example above does.

Source

pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()>

Saves the classifier to the given path.

Source

pub fn load<P: AsRef<Path>>( path: P, device: &Device, config: ClassifierConfig, ) -> Result<Self>

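Loads a classifier from the given path onto `device`. `config` should be the configuration the classifier was saved with; obtain it from `config` before calling `save`, as the example above does.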

Auto Trait Implementations§

§

impl<T> !Freeze for TextClassifier<T>

§

impl<T> !RefUnwindSafe for TextClassifier<T>

§

impl<T> Send for TextClassifier<T>
where T: Send,

§

impl<T> Sync for TextClassifier<T>
where T: Sync,

§

impl<T> Unpin for TextClassifier<T>
where T: Unpin,

§

impl<T> !UnwindSafe for TextClassifier<T>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

impl<T> ErasedDestructor for T
where T: 'static,

Source§

impl<T> ErasedDestructor for T
where T: 'static,