use std::sync::mpsc;
use anyhow::Result;
use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel};
use tokio::{
sync::oneshot,
task::{self, JoinHandle},
};
/// Example entry point: starts the classifier worker, classifies two sample
/// sentences, and prints the resulting sentiments.
#[tokio::main]
async fn main() -> Result<()> {
    // The join handle is kept in a named binding but never awaited. The
    // worker's receive loop ends once `classifier`, its only sender, is dropped.
    let (_worker, classifier) = SentimentClassifier::spawn();

    let samples = ["Classify this positive text", "Classify this negative text"];
    let texts: Vec<String> = samples.iter().map(|&s| s.to_owned()).collect();

    let sentiments = classifier.predict(texts).await?;
    println!("Results: {sentiments:?}");

    Ok(())
}
/// A request to the worker: the texts to classify, plus a one-shot channel on
/// which the worker sends back the model's [`Sentiment`] results.
type Message = (Vec<String>, oneshot::Sender<Vec<Sentiment>>);

/// Cloneable async handle to a sentiment model that runs on a dedicated
/// blocking worker.
///
/// The model is created and used only inside the worker started by
/// [`SentimentClassifier::spawn`]. Handles submit requests to it over a
/// bounded queue of [`Message`]s.
#[derive(Debug, Clone)]
pub struct SentimentClassifier {
    sender: mpsc::SyncSender<Message>,
}

impl SentimentClassifier {
    /// Maximum number of requests that may wait in the queue before senders
    /// have to wait for the worker to catch up.
    const QUEUE_CAPACITY: usize = 100;

    /// Starts the model worker via `task::spawn_blocking` and returns its join
    /// handle together with a handle for submitting requests.
    ///
    /// Must be called from within a Tokio runtime context.
    ///
    /// The worker loop ends once every `SentimentClassifier` clone has been
    /// dropped. The join handle then yields `Ok(())`, or the error returned
    /// while constructing the model.
    pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) {
        let (sender, receiver) = mpsc::sync_channel(Self::QUEUE_CAPACITY);
        let handle = task::spawn_blocking(move || Self::runner(receiver));
        (handle, SentimentClassifier { sender })
    }

    /// Worker loop: loads the model, then answers queued requests until every
    /// sender has been dropped.
    ///
    /// # Errors
    /// Returns the error from `SentimentModel::new` if the model cannot be
    /// built. In that case `receiver` is dropped on return, so later
    /// `predict` calls fail instead of waiting forever.
    fn runner(receiver: mpsc::Receiver<Message>) -> Result<()> {
        let model = SentimentModel::new(SentimentConfig::default())?;
        while let Ok((texts, reply)) = receiver.recv() {
            let texts: Vec<&str> = texts.iter().map(String::as_str).collect();
            let sentiments = model.predict(texts);
            // The requester may have stopped waiting (for example, its `predict`
            // future was cancelled). That affects only this request, so drop
            // the result rather than panicking, which would take the worker
            // down for every other handle.
            let _ = reply.send(sentiments);
        }
        Ok(())
    }

    /// Classifies `texts` on the worker and returns the sentiments the model
    /// produced for them.
    ///
    /// # Errors
    /// Fails if the worker is gone (for example, the model failed to load),
    /// either before the request could be queued or before it was answered.
    /// Also fails if the blocking hand-off task cannot be joined.
    pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> {
        let (reply, response) = oneshot::channel();
        let request = (texts, reply);

        // `SyncSender::send` blocks the current thread while the queue is full,
        // which would stall an async executor thread. Try a non-blocking send
        // first. Only when the queue is full, move the blocking send onto the
        // blocking pool, so backpressure is kept without blocking the runtime.
        match self.sender.try_send(request) {
            Ok(()) => {}
            Err(mpsc::TrySendError::Full(request)) => {
                let sender = self.sender.clone();
                task::spawn_blocking(move || sender.send(request)).await??;
            }
            Err(err @ mpsc::TrySendError::Disconnected(_)) => return Err(err.into()),
        }

        Ok(response.await?)
    }
}