use std::{collections::HashMap, error::Error, fs::File, io::Read, sync::Arc, time::Duration};
use ordered_float::OrderedFloat;
use quick_xml::{de::from_str, DeError};
use serde::Deserialize;
use crate::{
CbsConfig, ConflictBasedSearch, Graph, GraphEdgeId, GraphNodeId, MyTime, SimpleEdgeData,
SimpleHeuristic, SimpleNodeData, SimpleState, SimpleWorld, Task,
};
pub fn get_cbs_from_files(
map_file: &str,
task_file: &str,
config_file: &str,
n_agents: usize,
n_threads: usize,
) -> (
Arc<Graph<SimpleNodeData, SimpleEdgeData>>,
ConflictBasedSearch<SimpleWorld, SimpleState, GraphEdgeId, MyTime, MyTime, SimpleHeuristic>,
CbsConfig<SimpleWorld, SimpleState, GraphEdgeId, MyTime, MyTime, SimpleHeuristic>,
f64,
) {
let (graph, tasks, config) = parse_inputs(map_file, task_file, config_file, n_agents).unwrap();
let transition_system = Arc::new(SimpleWorld::new(graph.clone(), config.agent_size));
(
graph,
ConflictBasedSearch::new(transition_system.clone()),
CbsConfig::new(
transition_system,
tasks,
OrderedFloat(config.precision),
n_threads,
Some(Duration::from_secs_f64(config.time_limit)),
),
config.agent_size,
)
}
/// Reads the entire file at `filename` into a `String`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or its contents are not valid UTF-8.
/// The boxed error is still a `std::io::Error` with the original `ErrorKind`, but its
/// message is prefixed with `filename` so callers that load several files can tell which
/// one failed.
pub fn read_from_file(filename: &str) -> Result<String, Box<dyn Error>> {
    // Re-wrap I/O errors so the message names the offending file while keeping the kind.
    let with_path = |err: std::io::Error| {
        std::io::Error::new(err.kind(), format!("{}: {}", filename, err))
    };

    let mut contents = String::new();
    File::open(filename)
        .map_err(with_path)?
        .read_to_string(&mut contents)
        .map_err(with_path)?;
    Ok(contents)
}
pub fn parse_inputs(
map_file: &str,
task_file: &str,
config_file: &str,
n_agents: usize,
) -> Result<
(
Arc<Graph<SimpleNodeData, SimpleEdgeData>>,
Vec<Arc<Task<SimpleState, MyTime>>>,
Config,
),
Box<dyn Error>,
> {
let contents = read_from_file(map_file)?;
let data: Result<Map, DeError> = from_str(&contents);
let map = data?;
let contents = read_from_file(task_file)?;
let data: Result<Scenario, DeError> = from_str(&contents);
let mut scenario = data?;
scenario.agents.truncate(n_agents);
let config = parse_config(config_file)?;
let mut graph = Graph::new();
let mut tasks = Vec::new();
if let Some(map) = map.grid {
let mut grid = vec![vec![GraphNodeId(0); map.width]; map.height];
for x in 0..map.width {
for y in 0..map.height {
if map.grid.rows[y][x] == 1 {
continue;
}
grid[y][x] = graph.add_node((x as f64, y as f64));
}
}
for x in 0..map.width {
for y in 0..map.height {
if map.grid.rows[y][x] == 1 {
continue;
}
let node_id = grid[y][x];
if x > 0 && map.grid.rows[y][x - 1] == 0 {
graph.add_edge(node_id, grid[y][x - 1], 1.0);
}
if y > 0 && map.grid.rows[y - 1][x] == 0 {
graph.add_edge(node_id, grid[y - 1][x], 1.0);
}
if x < map.width - 1 && map.grid.rows[y][x + 1] == 0 {
graph.add_edge(node_id, grid[y][x + 1], 1.0);
}
if y < map.height - 1 && map.grid.rows[y + 1][x] == 0 {
graph.add_edge(node_id, grid[y + 1][x], 1.0);
}
}
}
for agent in scenario.agents {
let initial_state = SimpleState(grid[agent.start_i.unwrap()][agent.start_j.unwrap()]);
let goal_state = SimpleState(grid[agent.goal_i.unwrap()][agent.goal_j.unwrap()]);
tasks.push(Arc::new(Task::new(
initial_state,
goal_state,
OrderedFloat(0.0),
)));
}
} else if let Some(map) = map.graph {
let mut nodes = HashMap::new();
for node in map.nodes {
let position = node
.position
.split(',')
.map(|n| n.parse().unwrap())
.collect::<Vec<f64>>();
nodes.insert(node.id, graph.add_node((position[0], position[1])));
}
for edge in map.edges {
graph.add_edge(nodes[&edge.source], nodes[&edge.target], edge.weight);
}
for agent in scenario.agents {
let start = nodes[&("n".to_string() + &agent.start_id.unwrap().to_string())];
let goal = nodes[&("n".to_string() + &agent.goal_id.unwrap().to_string())];
let initial_state = SimpleState(start);
let goal_state = SimpleState(goal);
tasks.push(Arc::new(Task::new(
initial_state,
goal_state,
OrderedFloat(0.0),
)));
}
} else {
return Err("No map found".into());
}
Ok((Arc::new(graph), tasks, config))
}
/// Top-level structure deserialized from the map file.
///
/// Both members are optional. `parse_inputs` uses the grid when it is present, otherwise
/// the graph, and fails with "No map found" when neither exists.
#[derive(Debug, Deserialize)]
struct Map {
/// Grid description, read from the child named `map`.
#[serde(rename = "map")]
grid: Option<GridMap>,
/// Explicit graph description, read from the child named `graph`.
graph: Option<GraphMap>,
}
/// A rectangular grid map with declared dimensions and per-cell values.
#[derive(Debug, Deserialize)]
struct GridMap {
/// Number of columns; `parse_inputs` iterates `x` over `0..width`.
width: usize,
/// Number of rows; `parse_inputs` iterates `y` over `0..height`.
height: usize,
/// The cell values, indexed as `rows[y][x]`.
grid: Grid,
}
/// Cell values of a grid map, one inner vector per `row` element.
#[derive(Debug, Deserialize)]
struct Grid {
/// Rows of cell values. In `parse_inputs` a value of `1` means the cell gets no graph
/// node, and neighbouring cells are only linked to when their value is `0`.
#[serde(rename = "row")]
rows: Vec<Vec<usize>>,
}
/// An explicit graph map: a list of nodes and a list of edges between them.
#[derive(Debug, Deserialize)]
struct GraphMap {
/// All `node` children of the graph element.
#[serde(rename = "node")]
nodes: Vec<Node>,
/// All `edge` children of the graph element.
#[serde(rename = "edge")]
edges: Vec<Edge>,
}
/// A single node of an explicit graph map.
#[derive(Debug, Deserialize)]
struct Node {
/// Textual node identifier, read from the `id` XML attribute (quick-xml `@` prefix).
/// Edges refer to nodes by this value.
#[serde(rename = "@id")]
id: String,
/// Comma-separated numeric coordinates from the `data` child; `parse_inputs` parses the
/// parts as `f64` and uses the first two as the node's position.
#[serde(rename = "data")]
position: String,
}
/// A single edge of an explicit graph map.
#[derive(Debug, Deserialize)]
struct Edge {
/// Identifier of the node the edge starts at, read from the `source` XML attribute.
#[serde(rename = "@source")]
source: String,
/// Identifier of the node the edge ends at, read from the `target` XML attribute.
#[serde(rename = "@target")]
target: String,
/// Value of the `data` child; passed as the third argument to `graph.add_edge`.
#[serde(rename = "data")]
weight: f64,
}
/// Top-level structure deserialized from the task file: the list of agents to plan for.
#[derive(Debug, Deserialize)]
struct Scenario {
/// All `agent` children, in file order. `parse_inputs` truncates this list to the
/// requested number of agents.
#[serde(rename = "agent")]
pub agents: Vec<Agent>,
}
/// Start and goal of one agent, read from XML attributes.
///
/// Every field is optional at the deserialization level. The `*_i`/`*_j` pairs are required
/// when the map is a grid, and the `*_id` fields are required when the map is a graph.
#[derive(Debug, Deserialize)]
struct Agent {
/// First grid index of the start cell (used as `grid[start_i][start_j]`, i.e. the row).
#[serde(rename = "@start_i")]
pub start_i: Option<usize>,
/// Second grid index of the start cell (the column).
#[serde(rename = "@start_j")]
pub start_j: Option<usize>,
/// First grid index of the goal cell (used as `grid[goal_i][goal_j]`, i.e. the row).
#[serde(rename = "@goal_i")]
pub goal_i: Option<usize>,
/// Second grid index of the goal cell (the column).
#[serde(rename = "@goal_j")]
pub goal_j: Option<usize>,
/// Numeric part of the start node id; `parse_inputs` looks up the node named `n<start_id>`.
#[serde(rename = "@start_id")]
pub start_id: Option<usize>,
/// Numeric part of the goal node id; `parse_inputs` looks up the node named `n<goal_id>`.
#[serde(rename = "@goal_id")]
pub goal_id: Option<usize>,
}
/// Loads the solver configuration stored in the XML file at `filename`.
///
/// # Errors
///
/// Returns an error if the file cannot be read or does not deserialize into the expected
/// structure (a root containing an `algorithm` element).
pub fn parse_config(filename: &str) -> Result<Config, Box<dyn Error>> {
    let xml = read_from_file(filename)?;
    let root: ConfigRoot = from_str(&xml)?;
    Ok(root.config)
}
/// Top-level structure deserialized from the configuration file; it only wraps the
/// `algorithm` child that holds the actual settings.
#[derive(Debug, Deserialize)]
struct ConfigRoot {
/// Settings read from the child named `algorithm`.
#[serde(rename = "algorithm")]
pub config: Config,
}
/// Solver settings read from the configuration file.
#[derive(Debug, Deserialize)]
pub struct Config {
/// Agent size; passed to `SimpleWorld::new` and returned by `get_cbs_from_files`.
pub agent_size: f64,
/// Deserialized but not read anywhere in the code visible in this file.
/// NOTE(review): presumably the grid neighbourhood size; confirm before relying on it.
pub connectedness: usize,
/// Numeric precision; wrapped in `OrderedFloat` and passed to `CbsConfig::new`.
pub precision: f64,
/// Time limit read from the `timelimit` element; converted with
/// `Duration::from_secs_f64`, so it is interpreted as seconds.
#[serde(rename = "timelimit")]
pub time_limit: f64,
}