use crate::printer::Printer;
use anyhow::{anyhow, Context, Error, Result};
use log::info;
use reinfer_client::{
Client, DatasetFullName, NewDataset, NewEntityDef, NewLabelDef, NewLabelGroup, SourceIdentifier,
};
use serde::Deserialize;
use std::str::FromStr;
use structopt::StructOpt;
// Command-line arguments for creating a dataset; consumed by `create`.
//
// NOTE: plain `//` comments are used on purpose. With `StructOpt`, `///` doc
// comments are turned into CLI help text, so adding them would change the
// `--help` output of this command.
#[derive(Debug, StructOpt)]
pub struct CreateDatasetArgs {
    // Positional argument: full name of the dataset to create.
    #[structopt(name = "owner-name/dataset-name")]
    name: DatasetFullName,
    // Optional title, forwarded unchanged as `NewDataset::title`.
    #[structopt(long = "title")]
    title: Option<String>,
    // Optional description, forwarded unchanged as `NewDataset::description`.
    #[structopt(long = "description")]
    description: Option<String>,
    // Sentiment flag; `create` sends `false` when this is omitted.
    #[structopt(
        long = "has-sentiment",
        help = "Enable sentiment prediction for the dataset [default: false]"
    )]
    has_sentiment: Option<bool>,
    // Source identifiers for the new dataset; `create` resolves each one to a
    // source id through `Client::get_source` before creating the dataset.
    #[structopt(short = "s", long = "source")]
    sources: Vec<SourceIdentifier>,
    // JSON array of entity defs (parsed by `VecExt`); the `[]` default means
    // `create` sends no entity defs.
    #[structopt(short = "e", long = "entity-defs", default_value = "[]")]
    entity_defs: VecExt<NewEntityDef>,
    // JSON array of label defs (parsed by `VecExt`). `create` only sends these
    // when `label_groups` is empty.
    #[structopt(long = "label-defs", default_value = "[]")]
    label_defs: VecExt<NewLabelDef>,
    // JSON array of label groups (parsed by `VecExt`); the `[]` default means
    // `create` sends no label groups.
    #[structopt(long = "label-groups", default_value = "[]")]
    label_groups: VecExt<NewLabelGroup>,
    // Optional model family, forwarded unchanged as `NewDataset::model_family`.
    #[structopt(long = "model-family")]
    model_family: Option<String>,
    // Optional value forwarded unchanged as `NewDataset::copy_annotations_from`.
    #[structopt(long = "copy-annotations-from")]
    copy_annotations_from: Option<String>,
}
pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> Result<()> {
let CreateDatasetArgs {
name,
title,
description,
has_sentiment,
sources,
entity_defs,
label_defs,
label_groups,
model_family,
copy_annotations_from,
} = args;
let source_ids = {
let mut source_ids = Vec::with_capacity(sources.len());
for source in sources.iter() {
source_ids.push(
client
.get_source(source.clone())
.context("Operation to get source has failed")?
.id,
);
}
source_ids
};
let entity_defs = &entity_defs.0;
let label_groups = &label_groups.0;
let label_defs = match (!&label_defs.0.is_empty(), !label_groups.is_empty()) {
(true, false) => Some(&label_defs.0[..]),
_ => None,
};
let dataset = client
.create_dataset(
name,
NewDataset {
source_ids: &source_ids,
title: title.as_deref(),
description: description.as_deref(),
has_sentiment: Some(has_sentiment.unwrap_or(false)),
entity_defs: if entity_defs.is_empty() {
None
} else {
Some(entity_defs)
},
label_defs,
label_groups: if label_groups.is_empty() {
None
} else {
Some(&label_groups[..])
},
model_family: model_family.as_deref(),
copy_annotations_from: copy_annotations_from.as_deref(),
},
)
.context("Operation to create a dataset has failed.")?;
info!(
"New dataset `{}` [id: {}] created successfully",
dataset.full_name().0,
dataset.id.0,
);
printer.print_resources(&[dataset])?;
Ok(())
}
/// Newtype around `Vec<T>` so that a JSON array can be passed as a single
/// command-line value (e.g. `--label-defs '[...]'`) and parsed via `FromStr`.
#[derive(Debug, Deserialize)]
struct VecExt<T>(pub Vec<T>);

impl<T: serde::de::DeserializeOwned> FromStr for VecExt<T> {
    type Err = Error;

    /// Parse `string` as JSON into a `VecExt<T>`.
    ///
    /// # Errors
    ///
    /// Returns an error quoting the offending input and the `serde_json`
    /// failure reason when `string` is not valid JSON for this type. The
    /// reason is embedded in the message itself because argument parsers
    /// typically only display the error's `Display` output.
    fn from_str(string: &str) -> Result<Self> {
        serde_json::from_str(string).map_err(|source| {
            // `string` is borrowed straight into the message; no owned copy
            // is needed just to format it.
            anyhow!(
                "Expected valid json for type. Got: '{}', which failed because: '{}'",
                string,
                source
            )
        })
    }
}