katago_analysis/lib.rs
1//! A wrapper around KataGo's analysis protocol.
2//!
3//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
4//! for official documentation of the analysis engine.
5//!
6//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
7//!
8//! # Examples
9//!
10//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
11//!
12//! ```
13//! use katago_analysis::{
14//! AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
15//! engine::{Engine, LaunchOptions},
16//! };
17//!
18//! async fn example(
19//! katago_path: String,
20//! analysis_config_path: String,
21//! model_path: String,
22//! ) -> Result<()> {
23//! let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
24//! let mut analyzer: Analyzer = Engine::launch(&options)?.into();
25//!
26//! let request = AnalysisRequest::new(
27//! Rules::chinese(),
28//! 19,
29//! 19,
30//! vec![
31//! (Player::Black, Move::Move(Coord(15, 3))),
32//! (Player::White, Move::Move(Coord(3, 15))),
33//! ],
34//! );
35//!
36//! let results = analyzer.analyze_game(request).await?;
37//! for i in 0..results.len() {
38//! println!(
39//! "Move {i}: {:.1}%",
40//! results.get(&i).unwrap().root_info.winrate * 100.0
41//! );
42//! }
43//! Ok(())
44//! }
45//! ```
46
47#![cfg_attr(docsrs, feature(doc_cfg))]
48#![warn(missing_docs)]
49
50use std::{io, sync::Arc};
51
52use serde::{Deserialize, Serialize};
53use thiserror::Error;
54
55pub mod engine;
56
57mod analyzer;
58pub use analyzer::*;
59
60mod config;
61pub use config::*;
62
63mod request;
64pub use request::*;
65
66mod result;
67pub use result::*;
68
69mod rules;
70pub use rules::*;
71
72/// The type of results returned by methods in this library.
73pub type Result<T> = std::result::Result<T, Error>;
74
75/// The type of results which may contain warnings returned by the analysis engine.
76///
77/// See also: [`WarningHandling`]
78pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
79
80/// Errors that can occur while interacting with the analysis engine.
81#[derive(Debug, Clone, Error)]
82pub enum Error {
83 /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
84 #[error("I/O error: {0}")]
85 Io(#[from] Arc<io::Error>),
86
87 /// An error occurred while serializing or deserializing a message.
88 #[error("serialization error: {0}")]
89 Serialization(#[from] Arc<serde_json::Error>),
90
91 /// The engine's stdin was unavailable after launch.
92 #[error("stdin unavailable")]
93 StdinUnavailable,
94
95 /// The engine's stdout was unavailable after launch.
96 #[error("stdout unavailable")]
97 StdoutUnavailable,
98
99 /// The specified config setting's value could not be serialized.
100 #[error("unserializable config value for key {0}")]
101 UnserializableConfig(String),
102
103 /// The analysis engine returned an error response without specifying which request caused it.
104 ///
105 /// When this error occurs, all pending requests will return this error, even if they might have otherwise
106 /// succeeded.
107 #[error("invalid request: {error}")]
108 KataGoGeneralError {
109 /// The error message provided by KataGo.
110 error: String,
111 },
112
113 /// The analysis engine returned an error response.
114 ///
115 /// When this error occurs, all positions still being analyzed as part of the associated request will return this
116 /// error, even if they might have otherwise succeeded.
117 #[error("invalid field {field}: {error}")]
118 KataGoFieldError {
119 /// The error message provided by KataGo.
120 error: String,
121
122 /// The request field that caused the error.
123 field: String,
124 },
125
126 /// The analysis engine returned a warning response which was converted to an error.
127 ///
128 /// See also: [`WarningHandling`]
129 #[error("unhandled warnings: {0:?}")]
130 UnhandledWarnings(Vec<Warning>),
131}
132
133impl From<io::Error> for Error {
134 fn from(e: io::Error) -> Self {
135 Error::Io(Arc::new(e))
136 }
137}
138
139impl From<serde_json::Error> for Error {
140 fn from(e: serde_json::Error) -> Self {
141 Error::Serialization(Arc::new(e))
142 }
143}
144
/// A warning response from the analysis engine, tied to one request field.
#[derive(Clone, Debug)]
pub struct Warning {
    /// KataGo's warning text.
    pub warning: String,

    /// Name of the request field the warning refers to.
    pub field: String,
}
154
155/// Specifies how warnings from the analysis engine should be handled.
156///
157/// # Warning Handling
158///
159/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
160/// warnings will vary depending on which strategy is chosen.
161///
162/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
163/// when a warning occurs. This is the default strategy.
164/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
165/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
166/// This is not recommended unless you expect warnings to occur and don't intend to handle them.
167///
168/// # Examples
169///
170/// ### [`WarningsAsErrors`]
171///
172/// ```
173/// # use katago_analysis::*;
174/// async fn example(
175/// analyzer: &mut Analyzer,
176/// request: AnalysisRequest
177/// ) -> Result<()> {
178/// let result: AnalysisResult = analyzer
179/// .analyze(request)
180/// .await? // Warnings will cause this to fail with an error
181/// .unwrap();
182/// println!("{:.1}%", result.root_info.winrate * 100.0);
183/// Ok(())
184/// }
185/// ```
186///
187/// ### [`ReturnWarnings`]
188///
189/// ```
190/// # use katago_analysis::*;
191/// async fn example(
192/// analyzer: &mut Analyzer<ReturnWarnings>,
193/// request: AnalysisRequest,
194/// ) -> Result<()> {
195/// let result: AnalysisResult = analyzer
196/// .analyze(request)
197/// .await?
198/// .inspect_warnings(|warnings| {
199/// for warning in warnings {
200/// println!("{warning:?}");
201/// }
202/// })
203/// .value
204/// .unwrap();
205/// println!("{:.1}%", result.root_info.winrate * 100.0);
206/// Ok(())
207/// }
208/// ```
209///
210/// Warnings can also be converted back to errors:
211///
212/// ```
213/// # use katago_analysis::*;
214/// async fn example(
215/// analyzer: &mut Analyzer<ReturnWarnings>,
216/// request: AnalysisRequest,
217/// ) -> Result<()> {
218/// let result: AnalysisResult = analyzer
219/// .analyze(request)
220/// .await?
221/// .into_result()? // Warnings will cause this to fail with an error
222/// .unwrap();
223/// println!("{:.1}%", result.root_info.winrate * 100.0);
224/// Ok(())
225/// }
226/// ```
227///
228/// ### [`IgnoreWarnings`]
229///
230/// ```
231/// # use katago_analysis::*;
232/// async fn example(
233/// analyzer: &mut Analyzer<IgnoreWarnings>,
234/// request: AnalysisRequest
235/// ) -> Result<()> {
236/// let result: AnalysisResult = analyzer
237/// .analyze(request)
238/// .await? // Warnings will be ignored
239/// .unwrap();
240/// println!("{:.1}%", result.root_info.winrate * 100.0);
241/// Ok(())
242/// }
243/// ```
244pub trait WarningHandling {
245 /// The `Ok` result type.
246 type OkType<T>;
247
248 /// Creates a successful result containing the given value and no warnings.
249 fn ok<T>(val: T) -> WarningResult<T, Self>;
250
251 /// Updates the successful result, preserving errors and warnings.
252 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
253
254 /// Adds a new warning to the result.
255 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
256
257 /// Applies a function to the successful values of two results, merging warnings and errors.
258 fn merge<T, U, V>(
259 a: WarningResult<T, Self>,
260 b: WarningResult<U, Self>,
261 f: impl FnOnce(T, U) -> V,
262 ) -> WarningResult<V, Self>;
263}
264
265/// Warnings will return [`Error::UnhandledWarnings`].
266///
267/// See also: [`WarningHandling`]
268#[derive(Debug, Default, Clone)]
269pub struct WarningsAsErrors;
270
271impl WarningHandling for WarningsAsErrors {
272 type OkType<T> = T;
273
274 fn ok<T>(val: T) -> WarningResult<T, Self> {
275 Ok(val)
276 }
277
278 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
279 if let Ok(r) = result {
280 *r = value;
281 }
282 }
283
284 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
285 match result {
286 Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
287 Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
288 _ => {}
289 }
290 }
291
292 fn merge<T, U, V>(
293 a: WarningResult<T, Self>,
294 b: WarningResult<U, Self>,
295 f: impl FnOnce(T, U) -> V,
296 ) -> WarningResult<V, Self> {
297 Ok(f(a?, b?))
298 }
299}
300
301/// Warnings will be returned in a [`MaybeWarnings`].
302///
303/// See also: [`WarningHandling`]
304#[derive(Debug, Default, Clone)]
305pub struct ReturnWarnings;
306
307impl WarningHandling for ReturnWarnings {
308 type OkType<T> = MaybeWarnings<T>;
309
310 fn ok<T>(val: T) -> WarningResult<T, Self> {
311 Ok(MaybeWarnings {
312 value: val,
313 warnings: None,
314 })
315 }
316
317 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
318 if let Ok(r) = result {
319 r.value = value;
320 }
321 }
322
323 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
324 if let Ok(r) = result {
325 match r.warnings.as_mut() {
326 Some(warnings) => warnings.push(warning),
327 None => r.warnings = Some(vec![warning]),
328 }
329 }
330 }
331
332 fn merge<T, U, V>(
333 a: WarningResult<T, Self>,
334 b: WarningResult<U, Self>,
335 f: impl FnOnce(T, U) -> V,
336 ) -> WarningResult<V, Self> {
337 let MaybeWarnings {
338 value: a,
339 warnings: a_warnings,
340 } = a?;
341 let MaybeWarnings {
342 value: b,
343 warnings: b_warnings,
344 } = b?;
345 Ok(MaybeWarnings {
346 value: f(a, b),
347 warnings: a_warnings.or(b_warnings),
348 })
349 }
350}
351
352/// A result that may contain warnings.
353#[derive(Debug, Default, Clone)]
354pub struct MaybeWarnings<T> {
355 /// The successful result value.
356 pub value: T,
357
358 /// The list of warnings that occurred, if any.
359 pub warnings: Option<Vec<Warning>>,
360}
361
362impl<T> MaybeWarnings<T> {
363 /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
364 pub fn into_result(self) -> Result<T> {
365 match self.warnings {
366 Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
367 None => Ok(self.value),
368 }
369 }
370
371 /// Calls a function with a reference to the warnings, if any.
372 pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
373 if let Some(warnings) = &self.warnings {
374 f(warnings);
375 }
376 self
377 }
378}
379
380/// Warnings will be dropped.
381///
382/// See also: [`WarningHandling`]
383#[derive(Debug, Default, Clone)]
384pub struct IgnoreWarnings;
385
386impl WarningHandling for IgnoreWarnings {
387 type OkType<T> = T;
388
389 fn ok<T>(val: T) -> WarningResult<T, Self> {
390 Ok(val)
391 }
392
393 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
394 if let Ok(r) = result {
395 *r = value;
396 }
397 }
398
399 fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
400
401 fn merge<T, U, V>(
402 a: WarningResult<T, Self>,
403 b: WarningResult<U, Self>,
404 f: impl FnOnce(T, U) -> V,
405 ) -> WarningResult<V, Self> {
406 Ok(f(a?, b?))
407 }
408}
409
410/// Player colours.
411#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
412pub enum Player {
413 /// The black player.
414 #[serde(rename = "B")]
415 Black,
416 /// The white player.
417 #[serde(rename = "W")]
418 White,
419}
420
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::Color> for Player {
    /// Maps an SGF colour onto the matching [`Player`].
    fn from(value: sgf_parse::Color) -> Self {
        use sgf_parse::Color as SgfColor;
        match value {
            SgfColor::White => Self::White,
            SgfColor::Black => Self::Black,
        }
    }
}
430
/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// Converts a coordinate from GTP format.
    ///
    /// A GTP vertex is one or more column letters followed by a row number, e.g. `D4` or `AB12`.
    /// Letters are case-insensitive and `I` is skipped. Columns past `Z` continue `AA`, `AB`, …,
    /// matching [`Coord::to_gtp`]. Rows count upwards from 1 at the bottom of a board with `height`
    /// rows.
    ///
    /// Returns `None` in any of these cases:
    /// - the input is malformed (no letters, no digits, digits before letters, a sign, or an `I`);
    /// - the row is `0` or greater than `height`;
    /// - the column does not fit in a `u8`.
    ///
    /// The column is not checked against the board width.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        let split = s
            .find(|c: char| !c.is_ascii_alphabetic())
            .unwrap_or(s.len());
        let (letters, digits) = s.split_at(split);

        // `u8::from_str` would also accept a leading `+`, so insist on plain digits.
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }

        // Bijective base-25 column number (A=1 … H=8, J=9 … Z=25). Accumulated in u32 with checked
        // arithmetic so that large columns such as `KF` (255) or over-long inputs cannot overflow.
        let mut column: u32 = 0;
        for c in letters.bytes().map(|b| b.to_ascii_uppercase()) {
            let value = match c {
                b'A'..=b'H' => c - b'A' + 1,
                // `I` is not used in GTP, so J..Z shift down by one.
                b'J'..=b'Z' => c - b'A',
                _ => return None,
            };
            column = column.checked_mul(25)?.checked_add(u32::from(value))?;
        }
        let x = u8::try_from(column.checked_sub(1)?).ok()?;

        // Row 0 does not exist in GTP; rows above `height` are off the board.
        let row = digits.parse::<u8>().ok().filter(|&row| row >= 1)?;
        let y = height.checked_sub(row)?;
        Some(Self(x, y))
    }

    /// Converts a coordinate to GTP format on a board with `height` rows.
    ///
    /// For on-board coordinates (`y < height`) the result round-trips through
    /// [`Coord::from_gtp`]. If `y >= height`, the row is printed as zero or negative (e.g. `A-1`),
    /// which is not valid GTP, rather than panicking.
    pub fn to_gtp(self, height: u8) -> String {
        const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
        let Self(x, y) = self;
        // At most two column letters (x <= 255 means x / 25 <= 10) plus up to three row digits.
        let mut gtp = String::with_capacity(5);
        if x >= 25 {
            gtp.push(char::from(LETTERS[usize::from(x / 25) - 1]));
        }
        gtp.push(char::from(LETTERS[usize::from(x % 25)]));
        // Computed in i16 so an off-board `y` cannot underflow u8.
        let row = i16::from(height) - i16::from(y);
        gtp.push_str(&row.to_string());
        gtp
    }
}
471
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Point> for Coord {
    /// Copies the point's `x` and `y` directly.
    ///
    /// This assumes sgf-parse uses the same top-left origin as [`Coord`] (TODO confirm).
    fn from(c: sgf_parse::go::Point) -> Self {
        let sgf_parse::go::Point { x, y, .. } = c;
        Coord(x, y)
    }
}
478
479/// A move in a game.
480#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
481pub enum Move {
482 /// A move placing a stone at the specified coordinate.
483 Move(Coord),
484
485 /// A pass.
486 Pass,
487}
488
489impl Move {
490 /// Converts a move from GTP format.
491 pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
492 if s.eq_ignore_ascii_case("pass") {
493 Some(Self::Pass)
494 } else {
495 Coord::from_gtp(s, height).map(Self::Move)
496 }
497 }
498
499 /// Converts a move to GTP format.
500 pub fn to_gtp(self, height: u8) -> String {
501 match self {
502 Move::Move(coord) => coord.to_gtp(height),
503 Move::Pass => "pass".to_string(),
504 }
505 }
506}
507
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Move> for Move {
    /// Maps an SGF move onto a [`Move`], converting the point with `Coord::from`.
    fn from(m: sgf_parse::go::Move) -> Self {
        use sgf_parse::go::Move as SgfMove;
        match m {
            SgfMove::Pass => Self::Pass,
            SgfMove::Move(point) => Self::Move(Coord::from(point)),
        }
    }
}
517
/// Version details reported by KataGo.
#[derive(Clone, Debug)]
pub struct VersionInfo {
    /// The latest KataGo release this build descends from, e.g. `"1.6.1"`.
    pub version: String,

    /// The exact git hash of the build.
    ///
    /// This is `"<omitted>"` when KataGo was built outside its repository or without Git support.
    pub git_hash: String,
}
529
530/// Information about a neural network model.
531#[derive(Debug, Clone, Deserialize)]
532#[serde(rename_all = "camelCase")]
533pub struct Model {
534 /// The model name.
535 pub name: String,
536
537 /// The internal name.
538 pub internal_name: String,
539
540 /// The maximum batch size.
541 pub max_batch_size: u32,
542
543 /// Whether it uses a humanSL profile.
544 #[serde(rename = "usesHumanSLProfile")]
545 pub uses_humansl_profile: bool,
546
547 /// The model version.
548 pub version: u32,
549
550 /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
551 /// it will be enabled if the backend deems it to be beneficial.
552 #[serde(rename = "usingFP16")]
553 pub using_fp16: Enabled,
554}
555
556/// The enabled state of a feature.
557#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
558#[serde(rename_all = "lowercase")]
559pub enum Enabled {
560 /// The feature is disabled.
561 False,
562
563 /// The feature is enabled.
564 True,
565
566 /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
567 Auto,
568}
569
#[cfg(test)]
mod tests {
    use crate::Coord;

    // Parsing on a standard board and on oversized boards with two-letter columns.
    #[test]
    fn coord_from_gtp() {
        let cases: &[(&str, u8, Option<Coord>)] = &[
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for &(gtp, height, expected) in cases {
            assert_eq!(
                Coord::from_gtp(gtp, height),
                expected,
                "from_gtp({gtp:?}, {height})"
            );
        }
    }

    // Formatting mirrors the parsing cases above.
    #[test]
    fn coord_to_gtp() {
        let cases: &[(Coord, u8, &str)] = &[
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for &(coord, height, expected) in cases {
            assert_eq!(coord.to_gtp(height), expected, "{coord:?}.to_gtp({height})");
        }
    }
}