extern crate clap;
extern crate serde_json;
use clap::ArgMatches;
use crate::dachshund::beam::{Beam, BeamSearchResult};
use crate::dachshund::error::{CLQError, CLQResult};
use crate::dachshund::graph_base::GraphBase;
use crate::dachshund::graph_builder_base::GraphBuilderBase;
use crate::dachshund::id_types::{GraphId, NodeTypeId};
use crate::dachshund::line_processor::LineProcessorBase;
use crate::dachshund::non_core_type_ids::NonCoreTypeIds;
use crate::dachshund::row::{CliqueRow, EdgeRow, Row};
use crate::dachshund::search_problem::SearchProblem;
use crate::dachshund::transformer_base::TransformerBase;
use crate::dachshund::typed_graph::{LabeledGraph, TypedGraph};
use crate::dachshund::typed_graph_builder::TypedGraphBuilder;
use crate::dachshund::typed_graph_line_processor::TypedGraphLineProcessor;
use std::rc::Rc;
use std::sync::mpsc::Sender;
use std::sync::Arc;
/// Buffers the edge and clique rows of one graph and, once the batch is
/// complete, builds a typed graph from them and runs a beam search over it.
pub struct Transformer {
/// Name of the core node type; registered with node type id 0 by
/// `process_typespec`.
pub core_type: String,
/// Node type id table produced by `process_typespec`.
pub non_core_type_ids: Rc<NonCoreTypeIds>,
/// Third column of every typespec row, sorted.
pub non_core_types: Rc<Vec<String>>,
/// Second column of every typespec row, sorted.
pub edge_types: Rc<Vec<String>>,
/// Length of `non_core_types`.
pub num_non_core_types: usize,
/// Line processor handed out by `get_line_processor`.
pub line_processor: Arc<TypedGraphLineProcessor>,
/// Search parameters shared with every beam created by `process_graph`.
pub search_problem: Rc<SearchProblem>,
/// Passed as the `verbose` flag of the search in `process_batch`.
pub debug: bool,
/// When true the top candidate is emitted through its own `print` method,
/// otherwise as a single `graph_id<TAB>row` line.
pub long_format: bool,
// Edge rows accumulated by `process_row`; consumed by `process_batch`.
edge_rows: Vec<EdgeRow>,
// Clique rows accumulated by `process_row`; emptied only by `reset`.
clique_rows: Vec<CliqueRow>,
}
/// Row-buffering implementation of [`TransformerBase`]: rows are collected one
/// at a time and turned into a graph search when the batch is processed.
impl TransformerBase for Transformer {
    /// Returns a shared handle to this transformer's line processor.
    fn get_line_processor(&self) -> Arc<dyn LineProcessorBase> {
        self.line_processor.clone()
    }

    /// Buffers `row` as an edge row and/or a clique row, depending on which
    /// conversions it supports. Rows supporting neither are ignored.
    fn process_row(&mut self, row: Box<dyn Row>) -> CLQResult<()> {
        if let Some(edge_row) = row.as_edge_row() {
            self.edge_rows.push(edge_row);
        }
        if let Some(clique_row) = row.as_clique_row() {
            self.clique_rows.push(clique_row);
        }
        Ok(())
    }

    /// Discards all buffered edge and clique rows.
    fn reset(&mut self) -> CLQResult<()> {
        self.edge_rows.clear();
        self.clique_rows.clear();
        Ok(())
    }

    /// Builds the graph for `graph_id` from the buffered edge rows and runs the
    /// search over it, sending any result to `output`.
    ///
    /// The edge-row buffer is empty afterwards; the clique rows are only
    /// borrowed and stay buffered until [`TransformerBase::reset`] is called.
    ///
    /// # Errors
    /// Propagates any error from building the graph or running the search.
    fn process_batch(
        &mut self,
        graph_id: GraphId,
        output: &Sender<(Option<String>, bool)>,
    ) -> CLQResult<()> {
        // Move the whole buffer out in O(1) instead of draining it element by
        // element into a freshly allocated vector.
        let edge_rows = std::mem::take(&mut self.edge_rows);
        let graph: TypedGraph = self.build_pruned_graph(graph_id, edge_rows)?;
        self.process_clique_rows(&graph, &self.clique_rows, graph_id, self.debug, output)?;
        Ok(())
    }
}
impl Transformer {
pub fn process_typespec(
typespec: Vec<Vec<String>>,
core_type: &str,
non_core_types: Vec<String>,
) -> CLQResult<NonCoreTypeIds> {
let mut non_core_type_ids = NonCoreTypeIds::new();
non_core_type_ids.insert(core_type, NodeTypeId::from(0_usize));
let should_be_only_this_core_type = &typespec[0][0].clone();
for (non_core_type_ix, non_core_type) in non_core_types.iter().enumerate() {
non_core_type_ids.insert(non_core_type, NodeTypeId::from(non_core_type_ix + 1));
}
for item in typespec {
let core_type = &item[0];
let non_core_type = &item[2];
assert_eq!(core_type, should_be_only_this_core_type);
let non_core_type_id: &mut NodeTypeId = non_core_type_ids.require_mut(non_core_type)?;
non_core_type_id.increment_possible_edge_count();
}
Ok(non_core_type_ids)
}
#[allow(clippy::too_many_arguments)]
pub fn new(
typespec: Vec<Vec<String>>,
beam_size: usize,
alpha: f32,
global_thresh: Option<f32>,
local_thresh: Option<f32>,
num_to_search: usize,
num_epochs: usize,
max_repeated_prior_scores: usize,
debug: bool,
min_degree: usize,
core_type: String,
long_format: bool,
) -> CLQResult<Self> {
let search_problem = Rc::new(SearchProblem::new(
beam_size,
alpha,
global_thresh,
local_thresh,
num_to_search,
num_epochs,
max_repeated_prior_scores,
min_degree,
));
let mut edge_types_v: Vec<String> = typespec.iter().map(|x| x[1].clone()).collect();
edge_types_v.sort();
let edge_types = Rc::new(edge_types_v);
let mut non_core_types_v: Vec<String> = typespec.iter().map(|x| x[2].clone()).collect();
non_core_types_v.sort();
let non_core_types = Rc::new(non_core_types_v);
let num_non_core_types: usize = non_core_types.len();
let non_core_type_ids: Rc<NonCoreTypeIds> = Rc::new(Transformer::process_typespec(
typespec,
&core_type,
non_core_types.to_vec(),
)?);
let line_processor = Arc::new(TypedGraphLineProcessor::new(
core_type.clone(),
non_core_type_ids.clone(),
non_core_types.clone(),
edge_types.clone(),
));
let transformer = Self {
core_type,
non_core_type_ids,
non_core_types,
edge_types,
num_non_core_types,
line_processor,
search_problem,
debug,
long_format,
edge_rows: Vec::new(),
clique_rows: Vec::new(),
};
Ok(transformer)
}
/// Builds a transformer from parsed command-line arguments.
///
/// Reads `typespec` (a JSON array of string arrays), `beam_size`, `alpha`,
/// `global_thresh`, `local_thresh`, `num_to_search`, `epochs`,
/// `max_repeated_prior_scores`, `debug_mode`, `min_degree`, `core_type` and
/// `long_format`. All of them are required; the two thresholds are always
/// passed on as `Some(..)`.
///
/// # Errors
/// Returns an error when an argument is missing or fails to parse, and
/// propagates any error from [`Transformer::new`].
pub fn from_argmatches(matches: ArgMatches) -> CLQResult<Self> {
    // Fetches one argument by name, turning its absence into an error.
    let required = |name: &str| -> CLQResult<&str> {
        match matches.value_of(name) {
            Some(value) => Ok(value),
            None => Err(CLQError::from(format!("Missing required argument: {name}"))),
        }
    };
    let typespec: Vec<Vec<String>> = serde_json::from_str(required("typespec")?)?;
    let beam_size = required("beam_size")?.parse::<usize>()?;
    let alpha = required("alpha")?.parse::<f32>()?;
    let global_thresh = Some(required("global_thresh")?.parse::<f32>()?);
    let local_thresh = Some(required("local_thresh")?.parse::<f32>()?);
    let num_to_search = required("num_to_search")?.parse::<usize>()?;
    let num_epochs = required("epochs")?.parse::<usize>()?;
    let max_repeated_prior_scores = required("max_repeated_prior_scores")?.parse::<usize>()?;
    let debug = required("debug_mode")?.parse::<bool>()?;
    let min_degree = required("min_degree")?.parse::<usize>()?;
    let core_type = required("core_type")?.parse::<String>()?;
    let long_format = required("long_format")?.parse::<bool>()?;
    Transformer::new(
        typespec,
        beam_size,
        alpha,
        global_thresh,
        local_thresh,
        num_to_search,
        num_epochs,
        max_repeated_prior_scores,
        debug,
        min_degree,
        core_type,
        long_format,
    )
}
/// Builds the typed graph for `graph_id` from `rows`, passing the search
/// problem's `min_degree` to the graph builder.
///
/// NOTE(review): presumably the builder prunes nodes below `min_degree`;
/// confirm against `TypedGraphBuilder`.
///
/// # Errors
/// Propagates any error returned by the graph builder.
#[allow(clippy::ptr_arg)]
pub fn build_pruned_graph(
    &self,
    graph_id: GraphId,
    rows: Vec<EdgeRow>,
) -> CLQResult<TypedGraph> {
    let min_degree = Some(self.search_problem.min_degree);
    TypedGraphBuilder {
        graph_id,
        min_degree,
    }
    .from_vector(rows)
}
pub fn process_graph<'a>(
&'a self,
graph: &'a TypedGraph,
clique_rows: &'a Vec<CliqueRow>,
graph_id: GraphId,
verbose: bool,
) -> CLQResult<BeamSearchResult<'a, TypedGraph>> {
let mut beam: Beam<TypedGraph> = Beam::new(
graph,
clique_rows,
verbose,
&self.non_core_types,
self.search_problem.clone(),
graph_id,
)?;
beam.run_search()
}
pub fn process_clique_rows<'a>(
&'a self,
graph: &'a TypedGraph,
clique_rows: &'a Vec<CliqueRow>,
graph_id: GraphId,
verbose: bool,
output: &Sender<(Option<String>, bool)>,
) -> CLQResult<Option<BeamSearchResult<'a, TypedGraph>>> {
if graph.get_core_ids().is_empty() || graph.get_non_core_ids().unwrap().is_empty() {
output.send((None, false)).unwrap();
return Ok(None);
}
let result: BeamSearchResult<TypedGraph> =
self.process_graph(graph, clique_rows, graph_id, verbose)?;
if result.top_candidate.get_score()? > 0.0 {
if !self.long_format {
let line: String = format!(
"{}\t{}",
graph_id.value(),
result
.top_candidate
.to_printable_row(&self.non_core_types, graph.get_reverse_labels_map())?,
);
output.send((Some(line), false)).unwrap();
} else {
result.top_candidate.print(
graph_id,
&self.non_core_types,
&self.core_type,
output,
)?;
}
}
Ok(Some(result))
}
}