twgame-get 0.0.1

Get resources for TwGame replays
Documentation
//! Load maps once and use them for the lifetime of the program

use crate::Sha256Hash;
use sha2::{Digest, Sha256};
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    fs, io,
    sync::{Arc, Mutex, RwLock},
};
use twgame::Map;

/// Options used to construct a [`MapService`].
#[derive(Debug, Clone)]
pub struct MapServiceOpt {
    /// Directory in which maps are cached on disk as `<sha256>.map`.
    pub cache_directory: String,
    /// URLs of map archives. Stored by the service, but not consulted yet
    /// (downloading from external archives is still a TODO).
    pub maps_url: Vec<String>,
    /// When `true`, a map that is missing locally is never downloaded; the
    /// lookup panics instead.
    pub offline: bool,
}

/// Cache of loaded maps that can be queried by SHA-256 hash or by file path.
///
/// Constructed via `MapService::new`, which hands out the service inside an
/// `Arc`.
pub struct MapService {
    // Lookup tables of maps that were already loaded (and of hashes that are
    // known to be missing from the archive).
    inner: RwLock<MapServiceInner>,
    // Loader configuration. The mutex is held while a map is fetched from
    // disk / network so that each map is retrieved at most once.
    load_map: Mutex<LoadMap>,
}

/// Key identifying the map a caller asks for.
#[derive(Copy, Clone)]
enum Get<'a> {
    /// Request by SHA-256 hash; resolved to `<directory>/<hex sha256>.map`.
    Sha256(Sha256Hash),
    /// Request by explicit path to a map file.
    Path(&'a str),
}

impl<'a> Get<'a> {
    /// Returns the location on disk where the requested map is expected.
    ///
    /// A path request is handed back unchanged (borrowed). A hash request is
    /// turned into `{directory}/{hex-encoded sha256}.map`.
    fn path(&self, directory: &str) -> Cow<'a, str> {
        match *self {
            Self::Sha256(hash) => {
                Cow::Owned(format!("{directory}/{}.map", hex::encode(hash)))
            }
            Self::Path(file) => Cow::Borrowed(file),
        }
    }

    /// Returns the requested hash, or `None` when the map was requested by path.
    fn sha256(self) -> Option<Sha256Hash> {
        if let Self::Sha256(hash) = self {
            Some(hash)
        } else {
            None
        }
    }
}

impl MapService {
    /// Creates an empty map service from the given options.
    pub fn new(opt: MapServiceOpt) -> Arc<Self> {
        Arc::new(Self {
            inner: RwLock::new(MapServiceInner::new()),
            load_map: Mutex::new(LoadMap {
                directory: opt.cache_directory,
                urls: opt.maps_url,
                offline: opt.offline,
            }),
        })
    }

    /// Returns the map with the given SHA-256 hash, loading it from the cache
    /// directory or the map archive on first use.
    ///
    /// # Panics
    ///
    /// Panics when the map cannot be retrieved or parsed, see `get_inner`.
    pub fn get_sha256(&self, sha256: Sha256Hash) -> (CachedMap, Sha256Hash) {
        self.get_inner(Get::Sha256(sha256))
    }

    /// Returns the map stored at `path` together with the SHA-256 hash of the
    /// file, loading it from disk on first use.
    ///
    /// # Panics
    ///
    /// Panics when the file cannot be read or parsed, see `get_inner`.
    pub fn get_path(&self, path: &str) -> (CachedMap, Sha256Hash) {
        self.get_inner(Get::Path(path))
    }

    /// Looks the map up in the in-memory tables and loads it on a miss.
    ///
    /// # Panics
    ///
    /// Panics when the map is not in the map archive (also on every later
    /// request for the same hash), when loading fails inside
    /// `LoadMap::load_map`, or when the bytes cannot be parsed as a map.
    fn get_inner(&self, get: Get) -> (CachedMap, Sha256Hash) {
        {
            let service = self.inner.read().unwrap();
            if let Some(map) = service.get(get) {
                return map;
            }
        }
        // Retrieve from disk / network. Hold the lock so that each map is
        // retrieved at most once.
        let load_lock = match self.load_map.lock() {
            Ok(l) => l,
            // We currently panic while holding this lock, e.g. when the map
            // doesn't exist. `LoadMap` is never mutated, so keep going with the
            // poisoned lock in the multi-threaded context (tee-hee run).
            Err(e) => e.into_inner(),
        };
        // check whether the map was loaded while waiting for the lock
        {
            let service = self.inner.read().unwrap();
            if let Some(map) = service.get(get) {
                return map;
            }
        }
        match load_lock.load_map(get) {
            Ok((raw_map, sha256)) => {
                // A path request can resolve to content that is already cached
                // (the same map requested by hash earlier, or a second path to
                // the same file). Reuse that entry instead of treating it as an
                // invariant violation. All writes happen while `load_map` is
                // locked, so this lookup cannot race with another insertion.
                let known = self.inner.read().unwrap().sha256.get(&sha256).cloned();
                let cached_map = known.unwrap_or_else(|| {
                    let mut parsed_map = twmap::TwMap::parse(raw_map.as_ref())
                        .expect("Failed to load map with TwMap");
                    let map = Arc::new(
                        twgame::Map::try_from(&mut parsed_map).expect("Failed to load TwGame-Map"),
                    );
                    CachedMap {
                        map,
                        raw: Arc::clone(&raw_map),
                    }
                });
                // write map back into cache service
                let mut inner = self.inner.write().unwrap();
                inner
                    .sha256
                    .entry(sha256)
                    .or_insert_with(|| cached_map.clone());
                if let Get::Path(path) = get {
                    inner
                        .path
                        .insert(path.to_owned(), (cached_map.clone(), sha256));
                }
                (cached_map, sha256)
            }
            Err(sha256) => {
                // Remember the miss so later requests fail without hitting the
                // network again. The write guard is released before panicking.
                {
                    let mut inner = self.inner.write().unwrap();
                    assert!(
                        inner.sha256_missing.insert(sha256),
                        "Shouldn't try to load map twice from disk/network"
                    );
                }
                panic!(
                    "Map not in map archive (https://gitlab.com/ddnet-rs/ddnet-map-archive/). Use --map to specify path for map to use"
                );
            }
        }
    }

    /// Copies every `*.map` file found (recursively) below `in_path` to
    /// `out_path/<sha256>.map`. The source files are left in place.
    ///
    /// Directory entries that cannot be read and paths that are not valid
    /// UTF-8 are skipped.
    ///
    /// # Panics
    ///
    /// Panics when `out_path` cannot be created, or when a map file cannot be
    /// read or written.
    pub fn rename_to_sha256(in_path: &str, out_path: &str) {
        fs::create_dir_all(out_path).expect("failed to create maps directory");

        let maps = walkdir::WalkDir::new(in_path)
            .into_iter()
            // filter out errors
            .filter_map(|e| e.ok())
            // a directory named `*.map` is not a map file
            .filter(|e| !e.file_type().is_dir())
            .filter_map(|e| e.path().to_str().map(str::to_owned))
            .filter(|e| e.ends_with(".map"));

        for path in maps {
            let (map, sha256) = LoadMap::load_from_disk(&path)
                .unwrap_or_else(|err| panic!("failed to read map file {path}: {err}"));

            let sha256_str = hex::encode(sha256);
            let map_file = format!("{out_path}/{sha256_str}.map");
            fs::write(map_file, map.as_ref()).expect("failed to write map file to disk");
        }
    }
}

/// A loaded map in both parsed and raw form.
///
/// Both fields are `Arc`s, so cloning only bumps reference counts.
#[derive(Clone)]
pub struct CachedMap {
    /// The parsed map.
    pub map: Arc<Map>,
    /// The raw bytes of the map file.
    pub raw: Arc<[u8]>,
}

/// In-memory lookup tables behind `MapService`.
struct MapServiceInner {
    // Loaded maps, keyed by the SHA-256 hash of their raw bytes.
    sha256: HashMap<Sha256Hash, CachedMap>,
    // Hashes for which the map archive answered 404.
    sha256_missing: HashSet<Sha256Hash>,
    // Maps that were requested by path, together with their SHA-256 hash.
    path: HashMap<String, (CachedMap, Sha256Hash)>,
}

impl MapServiceInner {
    /// Creates the lookup tables, all of them empty.
    fn new() -> Self {
        let sha256 = HashMap::new();
        let sha256_missing = HashSet::new();
        let path = HashMap::new();
        Self {
            sha256,
            sha256_missing,
            path,
        }
    }

    /// Returns the cached map for `get`, or `None` if it has not been loaded.
    ///
    /// # Panics
    ///
    /// Panics when a hash is requested that was previously recorded as
    /// missing from the map archive.
    fn get(&self, get: Get) -> Option<(CachedMap, Sha256Hash)> {
        let wanted = match get {
            Get::Path(file) => return self.path.get(file).cloned(),
            Get::Sha256(hash) => hash,
        };
        if let Some(cached) = self.sha256.get(&wanted) {
            return Some((cached.clone(), wanted));
        }
        if self.sha256_missing.contains(&wanted) {
            panic!("Map not in map-archive (previous try returned 404)");
        }
        None
    }
}

/// Configuration for retrieving maps from disk or from the map archive.
struct LoadMap {
    // Directory in which maps are cached as `<sha256>.map`.
    directory: String,
    // TODO: try to download from external archive
    // Until then the field is stored but never read.
    #[allow(dead_code)]
    urls: Vec<String>,
    // When set, loading panics instead of falling back to the network.
    offline: bool,
}

impl LoadMap {
    /// Reads the file at `path` and returns its bytes together with their
    /// SHA-256 hash.
    ///
    /// # Errors
    ///
    /// Returns the I/O error when the file cannot be read.
    fn load_from_disk(path: &str) -> Result<(Arc<[u8]>, Sha256Hash), io::Error> {
        let bytes: Vec<u8> = std::fs::read(path)?;
        let mut hasher = Sha256::new();
        hasher.update(&bytes);
        let map_sha256 = hasher.finalize();
        Ok((Arc::from(bytes), map_sha256.into()))
    }

    /// Loads the requested map from disk, falling back to the map archive for
    /// requests by hash. A downloaded map is written to the cache directory.
    ///
    /// # Errors
    ///
    /// Returns `Err(sha256)` when the map archive answers 404 for that hash.
    ///
    /// # Panics
    ///
    /// Panics when a map requested by path cannot be read, when the map is
    /// not available locally and network loading is disabled, when the
    /// download fails or returns a status other than 200/404, when the
    /// downloaded bytes do not match the requested hash, or when the map
    /// cannot be written to the cache directory.
    fn load_map(&self, get: Get) -> Result<(Arc<[u8]>, Sha256Hash), Sha256Hash> {
        let map_file = get.path(&self.directory);
        // check whether we can load the map from disk
        match Self::load_from_disk(&map_file) {
            Ok((raw, found)) => {
                // A cache file named after a hash must really have that hash.
                // A corrupt or truncated file is treated like a missing one,
                // so it gets replaced by a fresh download below.
                if get.sha256().map_or(true, |wanted| wanted == found) {
                    return Ok((raw, found));
                }
            }
            Err(err) => {
                // only panic if file was requested. Silently drop error for sha256
                if matches!(get, Get::Path(_)) {
                    panic!("Error loading map file: {err}");
                }
            }
        }

        if self.offline {
            panic!("Couln't find map locally. Network load disabled");
        }

        // check whether we can load the map from network
        let sha256 = get
            .sha256()
            .expect("Should not try loading from network when path was given");

        let sha256_str = hex::encode(sha256);
        let url = format!(
            "https://gitlab.com/ddnet-rs/ddnet-map-archive/-/raw/master/sha256/{sha256_str}.map"
        );

        let response = reqwest::blocking::get(&url)
            .expect("unable to download files, are you connected to the internet?");
        if response.status().as_u16() == 404 {
            // Not in archive. Need to store permanent exception
            return Err(sha256);
        }
        assert_eq!(
            response.status().as_u16(),
            200,
            "Map-Archive Server Error, is {url} reachable?"
        );
        // `expect` takes a plain string and would print `{url}` verbatim, so
        // format the message (including the underlying error) explicitly.
        let body = response.bytes().unwrap_or_else(|err| {
            panic!("unable to download map from {url} are you connected to the internet? ({err})")
        });

        // verify sha256sum
        let mut hasher = Sha256::new();
        hasher.update(&body);
        let body_sha256: Sha256Hash = hasher.finalize().into();
        assert_eq!(
            body_sha256, sha256,
            "downloaded map has incorrect sha256sum"
        );

        let body_arc: Arc<[u8]> = body.to_vec().into();

        // write to disk
        fs::create_dir_all(&self.directory).expect("failed to create maps directory");
        fs::write(map_file.as_ref(), body_arc.as_ref())
            .expect("failed to write map file to disk");
        Ok((body_arc, sha256))
    }
}