use crate::progress::{Options as ProgressOptions, Progress};
use anyhow::{Context, Result};
use colored::Colorize;
use log::info;
use reinfer_client::{
resources::comment::{should_skip_serializing_optional_vec, EitherLabelling, HasAnnotations},
Client, CommentId, CommentUid, DatasetFullName, DatasetIdentifier, NewEntities, NewLabelling,
NewMoonForm, Source, SourceIdentifier,
};
use scoped_threadpool::Pool;
use serde::{Deserialize, Serialize};
use std::sync::mpsc::channel;
use std::{
fs::File,
io::{self, BufRead, BufReader},
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use structopt::StructOpt;
/// Upload annotations for existing comments to a dataset.
///
/// The doc comments on the fields below are rendered by structopt as the
/// `--help` text of the corresponding flags.
#[derive(Debug, StructOpt)]
pub struct CreateAnnotationsArgs {
    /// Path to a file with one JSON annotation per line. If not specified, stdin is used.
    #[structopt(short = "f", long = "file", parse(from_os_str))]
    annotations_path: Option<PathBuf>,

    /// Source containing the comments being annotated.
    #[structopt(short = "s", long = "source")]
    source: SourceIdentifier,

    /// Dataset in which the annotations are created.
    #[structopt(short = "d", long = "dataset")]
    dataset: DatasetIdentifier,

    /// Don't display a progress bar (only applicable when --file is used).
    #[structopt(long)]
    no_progress: bool,

    /// Upload the `moon_forms` field of each annotation instead of its
    /// `labelling` and `entities` fields.
    #[structopt(long)]
    use_moon_forms: bool,

    /// Number of annotations to accumulate before uploading them as a batch.
    #[structopt(long = "batch-size", default_value = "128")]
    batch_size: usize,
}
/// Uploads annotations, read from a file or from stdin, to a dataset.
///
/// The source and dataset named in `args` are looked up first. Annotations
/// are then read line by line from `args.annotations_path` (or stdin when no
/// path is given) and uploaded in batches using `pool`.
///
/// A progress bar is shown only when reading from a file and `--no-progress`
/// was not passed; its total is the size of the file in bytes.
///
/// # Errors
///
/// Returns an error if the source or dataset cannot be fetched, the input
/// file cannot be opened or inspected, a line cannot be read or parsed, or
/// an upload fails.
///
/// # Panics
///
/// Panics if the statistics are still shared after the progress bar has been
/// finished.
pub fn create(client: &Client, args: &CreateAnnotationsArgs, pool: &mut Pool) -> Result<()> {
    let source = client
        .get_source(args.source.clone())
        .with_context(|| format!("Unable to get source {}", args.source))?;
    let source_name = source.full_name();

    let dataset = client
        .get_dataset(args.dataset.clone())
        .with_context(|| format!("Unable to get dataset {}", args.dataset))?;
    let dataset_name = dataset.full_name();

    let statistics = match &args.annotations_path {
        Some(annotations_path) => {
            info!(
                "Uploading annotations from file `{}` to source `{}` [id: {}] and dataset `{}` [id: {}]",
                annotations_path.display(),
                source_name.0,
                source.id.0,
                dataset_name.0,
                dataset.id.0,
            );

            let file = BufReader::new(File::open(annotations_path).with_context(|| {
                format!("Could not open file `{}`", annotations_path.display())
            })?);
            let file_metadata = file.get_ref().metadata().with_context(|| {
                format!(
                    "Could not get file metadata for `{}`",
                    annotations_path.display()
                )
            })?;

            // The statistics are shared with the progress bar, hence the `Arc`.
            let statistics = Arc::new(Statistics::new());
            let progress = if args.no_progress {
                None
            } else {
                Some(progress_bar(file_metadata.len(), &statistics))
            };

            upload_annotations_from_reader(
                client,
                &source,
                file,
                &statistics,
                &dataset_name,
                args.use_moon_forms,
                args.batch_size,
                pool,
            )?;

            if let Some(mut progress) = progress {
                progress.done();
            }

            // Once the progress bar is finished we expect to be the sole owner again.
            Arc::try_unwrap(statistics)
                .expect("Not all references to `statistics` have been disposed of")
        }
        None => {
            info!(
                "Uploading annotations from stdin to source `{}` [id: {}] and dataset `{}` [id: {}]",
                source_name.0, source.id.0, dataset_name.0, dataset.id.0
            );

            // No progress bar for stdin: the total input size is unknown.
            let statistics = Statistics::new();
            upload_annotations_from_reader(
                client,
                &source,
                BufReader::new(io::stdin()),
                &statistics,
                &dataset_name,
                args.use_moon_forms,
                args.batch_size,
                pool,
            )?;
            statistics
        }
    };

    info!(
        "Successfully uploaded {} annotations.",
        statistics.num_annotations(),
    );
    Ok(())
}
/// Sink for annotation counts reported while batches are uploaded.
///
/// `upload_batch_of_annotations` calls into this from tasks running on its
/// thread pool, which is why recording takes `&self` and callers there also
/// require `Sync`.
pub trait AnnotationStatistic {
/// Counts one annotation.
fn add_annotation(&self);
}
/// Uploads every annotation in `annotations_to_upload`, one pool task per
/// annotation, and waits for all of them to finish.
///
/// Each annotation is applied to the comment whose UID is
/// `<source id>.<comment id>` in the dataset `dataset_name`. When
/// `use_moon_forms` is `true` only the annotation's `moon_forms` are sent;
/// otherwise its `labelling` and `entities` are sent.
///
/// `statistics.add_annotation()` is called once for each annotation that was
/// uploaded successfully; failed uploads are not counted.
///
/// # Errors
///
/// If any upload fails, one of the errors is returned and
/// `annotations_to_upload` is left untouched. On success the vector is
/// cleared so that it can be reused for the next batch.
pub fn upload_batch_of_annotations(
    annotations_to_upload: &mut Vec<NewAnnotation>,
    client: &Client,
    source: &Source,
    statistics: &(impl AnnotationStatistic + std::marker::Sync),
    dataset_name: &DatasetFullName,
    use_moon_forms: bool,
    pool: &mut Pool,
) -> Result<()> {
    // Worker tasks report failures through this channel; it is inspected
    // once every task has finished.
    let (error_sender, error_receiver) = channel();

    pool.scoped(|scope| {
        for new_comment in annotations_to_upload.iter() {
            let error_sender = error_sender.clone();
            scope.execute(move || {
                let comment_uid =
                    CommentUid(format!("{}.{}", source.id.0, new_comment.comment.id.0));

                let result = (if use_moon_forms {
                    client.update_labelling(
                        dataset_name,
                        &comment_uid,
                        None,
                        None,
                        new_comment.moon_forms.as_deref(),
                    )
                } else {
                    client.update_labelling(
                        dataset_name,
                        &comment_uid,
                        new_comment
                            .labelling
                            .clone()
                            .map(Into::<Vec<NewLabelling>>::into)
                            .as_deref(),
                        new_comment.entities.as_ref(),
                        None,
                    )
                })
                .with_context(|| {
                    format!(
                        "Could not update labelling for comment `{}`",
                        &comment_uid.0
                    )
                });

                match result {
                    // Only annotations that were actually uploaded are counted.
                    Ok(_) => statistics.add_annotation(),
                    // The receiver outlives the scope, so sending cannot fail
                    // unless that invariant is broken.
                    Err(error) => error_sender.send(error).expect("Could not send error"),
                }
            });
        }
    });

    match error_receiver.try_recv() {
        Ok(error) => Err(error),
        Err(_) => {
            annotations_to_upload.clear();
            Ok(())
        }
    }
}
/// Streams annotations from `annotations` and uploads them in batches.
///
/// Lines are parsed with `read_annotations_iter`, which also records the
/// number of bytes read in `statistics`. Entries for which
/// `has_annotations()` is `false` are skipped. A batch is uploaded as soon as
/// it holds at least `batch_size` annotations, and whatever is left when the
/// input ends is uploaded as a final batch.
///
/// # Errors
///
/// Stops at, and returns, the first read, parse or upload error.
// The upload needs the client, source, dataset and batching settings
// together, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn upload_annotations_from_reader(
    client: &Client,
    source: &Source,
    annotations: impl BufRead,
    statistics: &Statistics,
    dataset_name: &DatasetFullName,
    use_moon_forms: bool,
    batch_size: usize,
    pool: &mut Pool,
) -> Result<()> {
    let mut pending = Vec::new();

    for parsed in read_annotations_iter(annotations, Some(statistics)) {
        let annotation = parsed?;
        if !annotation.has_annotations() {
            // Nothing to upload for this comment.
            continue;
        }

        pending.push(annotation);
        if pending.len() < batch_size {
            continue;
        }

        upload_batch_of_annotations(
            &mut pending,
            client,
            source,
            statistics,
            dataset_name,
            use_moon_forms,
            pool,
        )?;
    }

    if pending.is_empty() {
        return Ok(());
    }

    // Flush the final, partially filled batch.
    upload_batch_of_annotations(
        &mut pending,
        client,
        source,
        statistics,
        dataset_name,
        use_moon_forms,
        pool,
    )
}
/// The `comment` part of a [`NewAnnotation`]: only the comment's id.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct CommentIdComment {
/// Id of the comment; combined with the source id to form the comment UID
/// when the annotation is uploaded.
pub id: CommentId,
}
/// One annotation record as read from the input: a comment reference plus
/// the annotations to apply to it.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct NewAnnotation {
/// The comment the annotations apply to.
pub comment: CommentIdComment,
/// Labelling to apply; sent only when moon forms are not in use. Omitted
/// from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub labelling: Option<EitherLabelling>,
/// Entities to apply; sent only when moon forms are not in use. Omitted
/// from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub entities: Option<NewEntities>,
/// Moon forms to apply; sent only when moon forms are in use. Falls back to
/// the default (`None`) when absent from the input; serialization is skipped
/// according to `should_skip_serializing_optional_vec`.
#[serde(skip_serializing_if = "should_skip_serializing_optional_vec", default)]
pub moon_forms: Option<Vec<NewMoonForm>>,
}
impl HasAnnotations for NewAnnotation {
    /// Returns `true` when at least one of `labelling`, `entities` or
    /// `moon_forms` reports having annotations.
    fn has_annotations(&self) -> bool {
        if self.labelling.has_annotations() {
            return true;
        }
        if self.entities.has_annotations() {
            return true;
        }
        self.moon_forms.has_annotations()
    }
}
/// Lazily parses `annotations` as one JSON-encoded [`NewAnnotation`] per line.
///
/// The returned iterator yields one item per input line and ends when the
/// reader reports end of input. When `statistics` is provided, the byte
/// length of every line read is added to it. A single line buffer is reused
/// across iterations.
///
/// Read and parse failures are yielded as `Err` items carrying the 1-based
/// line number; the iterator itself does not stop after an error.
fn read_annotations_iter<'a>(
    mut annotations: impl BufRead + 'a,
    statistics: Option<&'a Statistics>,
) -> impl Iterator<Item = Result<NewAnnotation>> + 'a {
    let mut buffer = String::new();
    let mut line_number: u32 = 0;

    std::iter::from_fn(move || {
        line_number += 1;
        buffer.clear();

        let bytes_read = match annotations
            .read_line(&mut buffer)
            .with_context(|| format!("Could not read line {line_number} from input stream"))
        {
            // Zero bytes read means end of input.
            Ok(0) => return None,
            Ok(bytes_read) => bytes_read,
            Err(error) => return Some(Err(error)),
        };

        if let Some(statistics) = statistics {
            statistics.add_bytes_read(bytes_read);
        }

        let parsed = serde_json::from_str::<NewAnnotation>(buffer.trim_end()).with_context(|| {
            format!("Could not parse annotations at line {line_number} from input stream")
        });
        Some(parsed)
    })
}
/// Counters describing the progress of an annotation upload.
///
/// Both counters are atomics so they can be updated through a shared
/// reference.
#[derive(Debug)]
pub struct Statistics {
// Number of input bytes read so far.
bytes_read: AtomicUsize,
// Number of annotations counted so far.
annotations: AtomicUsize,
}
impl AnnotationStatistic for Statistics {
    /// Increments the annotation counter by one.
    fn add_annotation(&self) {
        let _previous = self.annotations.fetch_add(1, Ordering::SeqCst);
    }
}
impl Statistics {
fn new() -> Self {
Self {
bytes_read: AtomicUsize::new(0),
annotations: AtomicUsize::new(0),
}
}
#[inline]
fn add_bytes_read(&self, bytes_read: usize) {
self.bytes_read.fetch_add(bytes_read, Ordering::SeqCst);
}
#[inline]
fn bytes_read(&self) -> usize {
self.bytes_read.load(Ordering::SeqCst)
}
#[inline]
pub fn num_annotations(&self) -> usize {
self.annotations.load(Ordering::SeqCst)
}
}
/// Produces the values the progress bar displays: the number of bytes read
/// so far and a label of the form `<count> annotations`, with the word
/// "annotations" dimmed.
fn basic_statistics(statistics: &Statistics) -> (u64, String) {
    let position = statistics.bytes_read() as u64;
    let label = format!(
        "{} {}",
        statistics.num_annotations(),
        "annotations".dimmed(),
    );
    (position, label)
}
/// Builds a progress bar over `statistics` with a total of `total_bytes`,
/// displayed in byte units and rendered via `basic_statistics`.
fn progress_bar(total_bytes: u64, statistics: &Arc<Statistics>) -> Progress {
    let options = ProgressOptions { bytes_units: true };
    let total = Some(total_bytes);
    Progress::new(basic_statistics, statistics, total, options)
}