use crate::printer::Printer;
use anyhow::{anyhow, bail, Context, Error, Result};
use log::info;
use reinfer_client::{
resources::{dataset::DatasetFlag, entity_def::NewGeneralFieldDef},
Client, DatasetFullName, NewDataset, NewEntityDef, NewLabelDef, NewLabelGroup,
SourceIdentifier,
};
use serde::Deserialize;
use std::str::FromStr;
use structopt::StructOpt;
// Command-line arguments accepted by the dataset creation command (`create`).
//
// NOTE: plain `//` comments are used on purpose throughout this struct.
// structopt turns `///` doc comments into help/about text, so using them here
// would change the generated CLI help output.
#[derive(Debug, StructOpt)]
pub struct CreateDatasetArgs {
// Positional argument: full name of the dataset to create, shown in usage as
// `owner-name/dataset-name`.
#[structopt(name = "owner-name/dataset-name")]
name: DatasetFullName,
// Optional title, forwarded unchanged to the create request.
#[structopt(long = "title")]
title: Option<String>,
// Optional description, forwarded unchanged to the create request.
#[structopt(long = "description")]
description: Option<String>,
// Takes an explicit boolean value; `create` treats an absent value as `false`.
#[structopt(
long = "has-sentiment",
help = "Enable sentiment prediction for the dataset [default: false]"
)]
has_sentiment: Option<bool>,
// Sources to attach to the dataset. `create` resolves every entry to a
// source id via `Client::get_source` before creating the dataset.
#[structopt(short = "s", long = "source")]
sources: Vec<SourceIdentifier>,
// JSON text parsed through `VecExt`'s `FromStr` impl; defaults to the empty
// list `[]`. `create` omits the field from the request when the list is empty.
#[structopt(short = "e", long = "entity-defs", default_value = "[]")]
entity_defs: VecExt<NewEntityDef>,
// JSON text parsed through `VecExt`; omitted from the request when empty.
#[structopt(short = "g", long = "general-fields", default_value = "[]")]
general_fields: VecExt<NewGeneralFieldDef>,
// JSON text parsed through `VecExt`. `create` only sends these when
// `label_groups` is empty; otherwise they are not sent at all.
#[structopt(long = "label-defs", default_value = "[]")]
label_defs: VecExt<NewLabelDef>,
// JSON text parsed through `VecExt`; omitted from the request when empty.
#[structopt(long = "label-groups", default_value = "[]")]
label_groups: VecExt<NewLabelGroup>,
// Optional model family, forwarded unchanged to the create request.
#[structopt(long = "model-family")]
model_family: Option<String>,
// Optional value forwarded unchanged as `copy_annotations_from`.
#[structopt(long = "copy-annotations-from")]
copy_annotations_from: Option<String>,
// When true, `create` adds `DatasetFlag::Qos` to the dataset flags.
#[structopt(long = "qos")]
qos: Option<bool>,
// When true, `create` adds `DatasetFlag::ExternalMoonLlm`; rejected by
// `create` unless `gen_ai` is also true.
#[structopt(long = "external-llm")]
external_llm: Option<bool>,
// When true, `create` adds `DatasetFlag::Gpt4`.
#[structopt(long = "gen-ai")]
gen_ai: Option<bool>,
// When true, `create` adds `DatasetFlag::ZeroShotLabels`; rejected by
// `create` unless `gen_ai` is also true.
#[structopt(long = "zero-shot")]
zero_shot: Option<bool>,
}
/// Creates a new dataset from the parsed command-line arguments and prints it.
///
/// The feature switches are validated first, so an invalid combination is
/// reported before any request is sent. Each source identifier is then
/// resolved to its id with `Client::get_source`, the dataset is created with
/// `Client::create_dataset`, and the result is logged and handed to `printer`.
///
/// # Errors
///
/// Returns an error when:
/// - `--external-llm` or `--zero-shot` is enabled without `--gen-ai`,
/// - looking up any of the sources fails,
/// - the create-dataset request fails, or
/// - printing the created dataset fails.
pub fn create(client: &Client, args: &CreateDatasetArgs, printer: &Printer) -> Result<()> {
    let CreateDatasetArgs {
        name,
        title,
        description,
        has_sentiment,
        sources,
        entity_defs,
        general_fields,
        label_defs,
        label_groups,
        model_family,
        copy_annotations_from,
        qos,
        external_llm,
        gen_ai,
        zero_shot,
    } = args;

    // Validate the flag combination up front: it needs no network access, so
    // there is no reason to look up sources before rejecting bad input.
    let dataset_flags = resolve_dataset_flags(
        qos.unwrap_or_default(),
        external_llm.unwrap_or_default(),
        gen_ai.unwrap_or_default(),
        zero_shot.unwrap_or_default(),
    )?;

    // Resolve every source to its id, stopping at the first failed lookup.
    let source_ids = sources
        .iter()
        .map(|source| {
            client
                .get_source(source.clone())
                .context("Operation to get source has failed")
                .map(|source| source.id)
        })
        .collect::<Result<Vec<_>>>()?;

    // Label definitions are only sent when no label groups were supplied.
    let label_defs = if label_groups.0.is_empty() && !label_defs.0.is_empty() {
        Some(&label_defs.0[..])
    } else {
        None
    };

    let dataset = client
        .create_dataset(
            name,
            NewDataset {
                source_ids: &source_ids,
                title: title.as_deref(),
                description: description.as_deref(),
                has_sentiment: Some(has_sentiment.unwrap_or(false)),
                // Empty lists are left out of the request rather than sent as `[]`.
                entity_defs: if entity_defs.0.is_empty() {
                    None
                } else {
                    Some(&entity_defs.0)
                },
                general_fields: if general_fields.0.is_empty() {
                    None
                } else {
                    Some(&general_fields.0)
                },
                label_defs,
                label_groups: if label_groups.0.is_empty() {
                    None
                } else {
                    Some(&label_groups.0[..])
                },
                model_family: model_family.as_deref(),
                copy_annotations_from: copy_annotations_from.as_deref(),
                dataset_flags,
            },
        )
        .context("Operation to create a dataset has failed.")?;

    info!(
        "New dataset `{}` [id: {}] created successfully",
        dataset.full_name().0,
        dataset.id.0,
    );
    printer.print_resources(&[dataset])?;
    Ok(())
}

/// Translates the feature switches into the dataset flags sent on creation.
///
/// Flags are emitted in the order `Gpt4`, `ExternalMoonLlm`, `ZeroShotLabels`,
/// `Qos`, each only when its switch is enabled.
///
/// # Errors
///
/// Fails when `external_llm` or `zero_shot` is enabled while `gen_ai` is not.
fn resolve_dataset_flags(
    qos: bool,
    external_llm: bool,
    gen_ai: bool,
    zero_shot: bool,
) -> Result<Vec<DatasetFlag>> {
    if external_llm && !gen_ai {
        bail!("External Llm can only be used if gen ai features are enabled. Please add `--gen-ai true`");
    }
    if zero_shot && !gen_ai {
        bail!("Zero shot can only be used if gen ai features are enabled. Please add `--gen-ai true`");
    }

    let mut dataset_flags = Vec::new();
    if gen_ai {
        dataset_flags.push(DatasetFlag::Gpt4);
    }
    if external_llm {
        dataset_flags.push(DatasetFlag::ExternalMoonLlm);
    }
    if zero_shot {
        dataset_flags.push(DatasetFlag::ZeroShotLabels);
    }
    if qos {
        dataset_flags.push(DatasetFlag::Qos);
    }
    Ok(dataset_flags)
}
/// Newtype around `Vec<T>` used as the type of list-valued command-line
/// arguments.
///
/// Its `FromStr` implementation in this file decodes the argument text as
/// JSON. Callers read the parsed list through the public `.0` field.
#[derive(Debug, Deserialize)]
struct VecExt<T>(pub Vec<T>);
/// Parses a `VecExt<T>` from a string containing JSON.
impl<T: serde::de::DeserializeOwned> FromStr for VecExt<T> {
    type Err = Error;

    /// Decodes `string` as JSON into a `VecExt<T>`.
    ///
    /// # Errors
    ///
    /// Returns an error quoting both the offending input and the underlying
    /// `serde_json` failure when `string` cannot be deserialized.
    fn from_str(string: &str) -> Result<Self> {
        serde_json::from_str(string).map_err(|source| {
            // `&str` implements `Display`, so the input is formatted directly
            // instead of being copied into a temporary `String` first.
            anyhow!(
                "Expected valid json for type. Got: '{}', which failed because: '{}'",
                string,
                source
            )
        })
    }
}