#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]
use std::{io, sync::Arc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
pub mod engine;
mod analyzer;
pub use analyzer::*;
mod config;
pub use config::*;
mod request;
pub use request::*;
mod result;
pub use result::*;
mod rules;
pub use rules::*;
/// Result type used throughout this crate, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// A [`Result`] whose success type is chosen by the [`WarningHandling`]
/// strategy `W` (through [`WarningHandling::OkType`]).
///
/// Defaults to [`WarningsAsErrors`], whose success type is the bare `T`.
pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
/// Errors produced by this crate.
///
/// The type is `Clone`; wrapped errors that are not themselves `Clone`
/// (`std::io::Error`, `serde_json::Error`) are stored behind an [`Arc`].
#[derive(Debug, Clone, Error)]
pub enum Error {
/// An I/O error.
#[error("I/O error: {0}")]
Io(#[from] Arc<io::Error>),
/// An error from `serde_json` while serializing or deserializing JSON.
#[error("serialization error: {0}")]
Serialization(#[from] Arc<serde_json::Error>),
/// A stdin handle was not available (presumably the engine process's
/// stdin — confirm against the `engine` module).
#[error("stdin unavailable")]
StdinUnavailable,
/// A stdout handle was not available (presumably the engine process's
/// stdout — confirm against the `engine` module).
#[error("stdout unavailable")]
StdoutUnavailable,
/// The configuration value stored under the given key could not be
/// serialized.
#[error("unserializable config value for key {0}")]
UnserializableConfig(String),
/// A request-level error that is not tied to a particular field; the
/// name suggests it is reported by KataGo itself.
#[error("invalid request: {error}")]
KataGoGeneralError {
/// The error message.
error: String,
},
/// An error attributed to a specific field of a request.
#[error("invalid field {field}: {error}")]
KataGoFieldError {
/// The error message.
error: String,
/// The name of the offending field.
field: String,
},
/// Warnings that were not handled and are therefore reported as an error,
/// e.g. by [`WarningsAsErrors`] or [`MaybeWarnings::into_result`].
#[error("unhandled warnings: {0:?}")]
UnhandledWarnings(Vec<Warning>),
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error::Io(Arc::new(e))
}
}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Self {
Error::Serialization(Arc::new(e))
}
}
/// A warning attached to a named field.
///
/// What happens to it depends on the [`WarningHandling`] strategy in use:
/// it may become an error, be returned alongside the value, or be ignored.
#[derive(Debug, Clone)]
pub struct Warning {
/// The warning message.
pub warning: String,
/// The name of the field the warning refers to (presumably a field of a
/// request sent to KataGo — confirm against callers).
pub field: String,
}
/// Strategy for handling warnings produced while building a result.
///
/// Implementations in this crate: [`WarningsAsErrors`] (warnings become
/// [`Error::UnhandledWarnings`]), [`ReturnWarnings`] (warnings are returned
/// inside a [`MaybeWarnings`]), and [`IgnoreWarnings`] (warnings are dropped).
pub trait WarningHandling {
/// The success type that wraps a value of type `T` under this strategy.
type OkType<T>;
/// Wraps `val` in a successful result that carries no warnings.
fn ok<T>(val: T) -> WarningResult<T, Self>;
/// Replaces the value inside a successful `result` with `value`.
///
/// The implementations in this crate leave an `Err` untouched.
fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
/// Records `warning` against `result` according to this strategy.
fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
/// Combines the values of `a` and `b` with `f` when both succeed;
/// otherwise returns an error according to this strategy.
fn merge<T, U, V>(
a: WarningResult<T, Self>,
b: WarningResult<U, Self>,
f: impl FnOnce(T, U) -> V,
) -> WarningResult<V, Self>;
}
#[derive(Debug, Default, Clone)]
pub struct WarningsAsErrors;
impl WarningHandling for WarningsAsErrors {
type OkType<T> = T;
fn ok<T>(val: T) -> WarningResult<T, Self> {
Ok(val)
}
fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
if let Ok(r) = result {
*r = value;
}
}
fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
match result {
Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
_ => {}
}
}
fn merge<T, U, V>(
a: WarningResult<T, Self>,
b: WarningResult<U, Self>,
f: impl FnOnce(T, U) -> V,
) -> WarningResult<V, Self> {
Ok(f(a?, b?))
}
}
/// A [`WarningHandling`] strategy that keeps warnings alongside the value.
///
/// Successful results carry a [`MaybeWarnings`], so callers can inspect the
/// warnings or turn them into an error later with
/// [`MaybeWarnings::into_result`].
#[derive(Debug, Default, Clone)]
pub struct ReturnWarnings;

impl WarningHandling for ReturnWarnings {
    type OkType<T> = MaybeWarnings<T>;

    fn ok<T>(val: T) -> WarningResult<T, Self> {
        Ok(MaybeWarnings {
            value: val,
            warnings: None,
        })
    }

    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
        // An `Err` has no value slot to update, so it is left untouched.
        if let Ok(r) = result {
            r.value = value;
        }
    }

    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
        // Warnings are only attached to successful results; an `Err` already
        // explains why there is no value.
        if let Ok(r) = result {
            r.warnings.get_or_insert_with(Vec::new).push(warning);
        }
    }

    fn merge<T, U, V>(
        a: WarningResult<T, Self>,
        b: WarningResult<U, Self>,
        f: impl FnOnce(T, U) -> V,
    ) -> WarningResult<V, Self> {
        let MaybeWarnings {
            value: a,
            warnings: a_warnings,
        } = a?;
        let MaybeWarnings {
            value: b,
            warnings: b_warnings,
        } = b?;
        // Concatenate so warnings from both sides are reported; `Option::or`
        // would discard `b`'s warnings whenever `a` had any.
        let warnings = match (a_warnings, b_warnings) {
            (Some(mut both), Some(rest)) => {
                both.extend(rest);
                Some(both)
            }
            (only, None) | (None, only) => only,
        };
        Ok(MaybeWarnings {
            value: f(a, b),
            warnings,
        })
    }
}
/// A value together with any warnings produced while computing it.
///
/// This is the success type used by [`ReturnWarnings`].
#[derive(Debug, Default, Clone)]
pub struct MaybeWarnings<T> {
    /// The computed value.
    pub value: T,
    /// Warnings collected alongside `value`; `None` when there were none.
    pub warnings: Option<Vec<Warning>>,
}

impl<T> MaybeWarnings<T> {
    /// Converts into a plain [`Result`], treating any warning as an error.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnhandledWarnings`] if at least one warning is
    /// present. An empty warning list (`Some(vec![])`) counts as no warnings.
    pub fn into_result(self) -> Result<T> {
        // Both fields are public, so `Some(vec![])` is constructible; it must
        // not be reported as an error with an empty warning list.
        match self.warnings.filter(|warnings| !warnings.is_empty()) {
            Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
            None => Ok(self.value),
        }
    }

    /// Calls `f` with the warning list when `warnings` is `Some`, then
    /// returns `self` unchanged.
    ///
    /// Handy for logging warnings before carrying on with the value.
    pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
        if let Some(warnings) = &self.warnings {
            f(warnings);
        }
        self
    }
}
/// A [`WarningHandling`] strategy that discards every warning.
///
/// The success type is the bare value, and
/// [`add_warning`](WarningHandling::add_warning) is a no-op, so warnings
/// never influence the result.
#[derive(Debug, Default, Clone)]
pub struct IgnoreWarnings;

impl WarningHandling for IgnoreWarnings {
    type OkType<T> = T;

    fn ok<T>(value: T) -> WarningResult<T, Self> {
        Ok(value)
    }

    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
        // Only a successful result has a slot to overwrite.
        if let Ok(slot) = result.as_mut() {
            *slot = value;
        }
    }

    fn add_warning<T>(_: &mut WarningResult<T, Self>, _: Warning) {
        // Intentionally empty: this strategy ignores warnings.
    }

    fn merge<T, U, V>(
        a: WarningResult<T, Self>,
        b: WarningResult<U, Self>,
        f: impl FnOnce(T, U) -> V,
    ) -> WarningResult<V, Self> {
        let left = a?;
        let right = b?;
        Ok(f(left, right))
    }
}
/// A player (stone color).
///
/// Serialized as `"B"` for black and `"W"` for white.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Player {
/// The black player; serialized as `"B"`.
#[serde(rename = "B")]
Black,
/// The white player; serialized as `"W"`.
#[serde(rename = "W")]
White,
}
/// Maps an SGF color onto the corresponding [`Player`].
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::Color> for Player {
    fn from(color: sgf_parse::Color) -> Self {
        use sgf_parse::Color;
        match color {
            Color::White => Self::White,
            Color::Black => Self::Black,
        }
    }
}
/// A zero-based board coordinate `Coord(x, y)`.
///
/// `x` is the column counted from the left (GTP column `A` is `x = 0`).
/// `y` is the row counted from the top: on a board `height` rows tall, GTP
/// row `height` is `y = 0` and GTP row `1` is `y = height - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// Parses a GTP vertex such as `"D4"` or `"AA12"` on a board `height`
    /// rows tall.
    ///
    /// Column letters use the GTP alphabet, which skips `I`: `A`..`H`,
    /// `J`..`Z`, then `AA`, `AB`, and so on. Matching is ASCII
    /// case-insensitive. Letters and digits are separated by character class
    /// rather than position.
    ///
    /// Returns `None` in any of these cases:
    /// - the vertex contains `I`, or has no letters;
    /// - the row is not in `1..=height`;
    /// - the column does not fit in a `u8`.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        let (letters, digits): (String, String) =
            s.chars().partition(|c| c.is_ascii_alphabetic());

        // The column is bijective base 25 with digits A=1 ..= Z=25 (I skipped).
        // Accumulate in u32 and range-check at the end, so long or large
        // columns yield `None` instead of overflowing a `u8` mid-computation
        // (which panicked in debug builds, e.g. for "ZZ1" or even the valid
        // x = 255 vertex "KF1").
        let mut column: u32 = 0;
        for c in letters.chars().map(|c| c.to_ascii_uppercase()) {
            let digit = match c {
                'I' => return None,
                'A'..='H' => u32::from(c as u8 - b'A') + 1,
                // Letters after the skipped `I` shift down by one.
                _ => u32::from(c as u8 - b'A'),
            };
            column = column.checked_mul(25)?.checked_add(digit)?;
        }
        let x = u8::try_from(column.checked_sub(1)?).ok()?;

        // GTP rows are 1-based from the bottom; row 0 would map off the board.
        let row = digits.parse::<u8>().ok().filter(|&row| row != 0)?;
        let y = height.checked_sub(row)?;
        Some(Self(x, y))
    }

    /// Formats the coordinate as a GTP vertex on a board `height` rows tall.
    ///
    /// For `y < height` this is the inverse of [`Coord::from_gtp`]; for
    /// example, `Coord(0, 18).to_gtp(19)` is `"A1"`.
    ///
    /// # Panics
    ///
    /// Panics if `y > height`.
    pub fn to_gtp(self, height: u8) -> String {
        const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
        let Self(x, y) = self;
        // `height - y` already panicked here in debug builds but silently
        // wrapped in release; make the precondition explicit in both.
        let row = height
            .checked_sub(y)
            .expect("Coord::to_gtp: y must not be greater than height");
        // At most two letters (x <= 255) plus up to three digits.
        let mut gtp = String::with_capacity(5);
        if x >= 25 {
            gtp.push(char::from(LETTERS[usize::from(x / 25) - 1]));
        }
        gtp.push(char::from(LETTERS[usize::from(x % 25)]));
        gtp.push_str(&row.to_string());
        gtp
    }
}
/// Converts an SGF point into a [`Coord`], copying its `x` and `y` fields
/// unchanged.
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Point> for Coord {
    fn from(point: sgf_parse::go::Point) -> Self {
        let (x, y) = (point.x, point.y);
        Coord(x, y)
    }
}
/// A move: either placing a stone at a [`Coord`] or passing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Move {
    /// Play a stone at the given coordinate.
    Move(Coord),
    /// Pass without placing a stone.
    Pass,
}

impl Move {
    /// Parses a GTP move on a board `height` rows tall.
    ///
    /// Accepts `"pass"` in any ASCII case, or any vertex accepted by
    /// [`Coord::from_gtp`]; returns `None` otherwise.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        if !s.eq_ignore_ascii_case("pass") {
            return Coord::from_gtp(s, height).map(Move::Move);
        }
        Some(Move::Pass)
    }

    /// Formats the move for GTP on a board `height` rows tall.
    ///
    /// A pass is written as lowercase `"pass"`; a placement is formatted with
    /// [`Coord::to_gtp`].
    pub fn to_gtp(self, height: u8) -> String {
        match self {
            Self::Pass => String::from("pass"),
            Self::Move(coord) => coord.to_gtp(height),
        }
    }
}
/// Converts an SGF move into a [`Move`], mapping SGF points through
/// `Coord::from`.
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Move> for Move {
    fn from(m: sgf_parse::go::Move) -> Self {
        use sgf_parse::go::Move as SgfMove;
        match m {
            SgfMove::Pass => Self::Pass,
            SgfMove::Move(point) => Self::Move(Coord::from(point)),
        }
    }
}
/// Version information, presumably as reported by the KataGo engine
/// (confirm against the `engine` module).
#[derive(Debug, Clone)]
pub struct VersionInfo {
/// The version string.
pub version: String,
/// The git hash (presumably of the engine build).
pub git_hash: String,
}
/// Information about a loaded model (presumably a KataGo neural-network
/// model — confirm against the `engine` module).
///
/// Deserialized from a JSON object with camelCase keys; two keys use the
/// explicit renames shown on their fields.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Model {
/// The model name (JSON key `name`).
pub name: String,
/// The model's internal name (JSON key `internalName`).
pub internal_name: String,
/// Maximum batch size (JSON key `maxBatchSize`).
pub max_batch_size: u32,
/// Whether the model uses a human SL profile (JSON key
/// `usesHumanSLProfile`).
#[serde(rename = "usesHumanSLProfile")]
pub uses_humansl_profile: bool,
/// The model version (JSON key `version`).
pub version: u32,
/// FP16 setting (JSON key `usingFP16`): `"true"`, `"false"`, or `"auto"`.
#[serde(rename = "usingFP16")]
pub using_fp16: Enabled,
}
/// A three-state setting.
///
/// Deserialized from the lowercase strings `"false"`, `"true"`, or `"auto"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Enabled {
/// Disabled (`"false"`).
False,
/// Enabled (`"true"`).
True,
/// `"auto"` — presumably decided automatically by the engine.
Auto,
}
#[cfg(test)]
mod tests {
    use crate::Coord;

    #[test]
    fn coord_from_gtp() {
        // (vertex, board height, expected parse result)
        let cases = [
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for (vertex, height, expected) in cases {
            assert_eq!(
                Coord::from_gtp(vertex, height),
                expected,
                "Coord::from_gtp({vertex:?}, {height})"
            );
        }
    }

    #[test]
    fn coord_to_gtp() {
        // (coordinate, board height, expected GTP vertex)
        let cases = [
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for (coord, height, expected) in cases {
            assert_eq!(coord.to_gtp(height), expected, "{coord:?}.to_gtp({height})");
        }
    }
}