Skip to main content

katago_analysis/
request.rs

1#[cfg(feature = "sgf-parse")]
2use sgf_parse::{SgfNode, go::Prop};
3
4use crate::*;
5
/// A game record to be analyzed, along with analysis settings.
///
/// Build one with [`AnalysisRequest::new`] and configure it with the `with_*` methods, or (with
/// the `sgf-parse` feature) convert it from the root node of an SGF game tree.
///
/// Every `Option` field left as `None` is forwarded as `None` by
/// [`into_engine_request`](AnalysisRequest::into_engine_request). Board coordinates stored here
/// are converted to GTP notation at that point, using `board_y_size` as the board height.
#[derive(Debug, Clone)]
pub struct AnalysisRequest {
    /// The ruleset for this game.
    pub rules: Rules,

    /// The komi for this game.
    pub komi: Option<f64>,

    /// Bonus points white receives in handicap games.
    pub white_handicap_bonus: Option<Bonus>,

    /// The board width.
    pub board_x_size: u8,

    /// The board height. Also used as the height when converting coordinates to GTP notation.
    pub board_y_size: u8,

    /// The stones on the board before the first move.
    pub initial_stones: Option<Vec<(Player, Coord)>>,

    /// The player to move in the initial position.
    pub initial_player: Option<Player>,

    /// The moves played in the game, in order.
    pub moves: Vec<(Player, Move)>,

    /// The maximum number of visits to use.
    pub max_visits: Option<u32>,

    /// Root policy temperature.
    pub root_policy_temperature: Option<f64>,

    /// Root FPU reduction max.
    pub root_fpu_reduction_max: Option<f64>,

    /// The maximum length of the principal variation to return, not including the first move.
    pub analysis_pv_len: Option<usize>,

    /// Whether to return the ownership prediction.
    pub include_ownership: bool,

    /// Whether to return the standard deviation of the ownership prediction.
    pub include_ownership_stdev: bool,

    /// Whether to return the ownership prediction for each move.
    pub include_moves_ownership: bool,

    /// Whether to return the standard deviation of the ownership prediction for each move.
    pub include_moves_ownership_stdev: bool,

    /// Whether to return the neural network policy output.
    pub include_policy: bool,

    /// Whether to return the number of visits for each position in the principal variation.
    pub include_pv_visits: bool,

    /// Whether to return the predicted probability that the game will have a void result.
    pub include_no_result_value: bool,

    /// Moves which are forbidden.
    pub avoid_moves: Option<Vec<RestrictedMoves>>,

    /// Moves which are allowed. If specified, all other moves are forbidden.
    pub allow_moves: Option<Vec<RestrictedMoves>>,

    /// Config overrides for this request.
    pub override_settings: Option<Config>,

    /// Report partial analysis results every this many seconds.
    pub report_during_search_every: Option<f64>,

    /// The priority of this request.
    pub priority: Option<i32>,
}
81
82impl AnalysisRequest {
83    /// Creates a new analysis request with the minimum required parameters.
84    pub fn new(
85        rules: Rules,
86        board_x_size: u8,
87        board_y_size: u8,
88        moves: Vec<(Player, Move)>,
89    ) -> Self {
90        Self {
91            rules,
92            komi: None,
93            white_handicap_bonus: None,
94            board_x_size,
95            board_y_size,
96            initial_stones: None,
97            initial_player: None,
98            moves,
99            max_visits: None,
100            root_policy_temperature: None,
101            root_fpu_reduction_max: None,
102            analysis_pv_len: None,
103            include_ownership: false,
104            include_ownership_stdev: false,
105            include_moves_ownership: false,
106            include_moves_ownership_stdev: false,
107            include_policy: false,
108            include_pv_visits: false,
109            include_no_result_value: false,
110            avoid_moves: None,
111            allow_moves: None,
112            override_settings: None,
113            report_during_search_every: None,
114            priority: None,
115        }
116    }
117
118    /// Converts this request into the lower-level equivalent used by the [`engine`] module.
119    ///
120    /// You probably don't need to use this unless you're directly using the lower-level API in the [`engine`] module.
121    pub fn into_engine_request(
122        self,
123        id: String,
124        analyze_turns: Vec<usize>,
125        priorities: Option<Vec<i32>>,
126    ) -> engine::AnalysisRequest {
127        engine::AnalysisRequest {
128            id,
129            rules: self.rules,
130            komi: self.komi,
131            white_handicap_bonus: self.white_handicap_bonus,
132            board_x_size: self.board_x_size,
133            board_y_size: self.board_y_size,
134            initial_stones: self.initial_stones.map(|s| {
135                s.into_iter()
136                    .map(|(p, c)| (p, c.to_gtp(self.board_y_size)))
137                    .collect()
138            }),
139            initial_player: self.initial_player,
140            moves: self
141                .moves
142                .into_iter()
143                .map(|(p, m)| (p, m.to_gtp(self.board_y_size)))
144                .collect(),
145            analyze_turns: Some(analyze_turns),
146            max_visits: self.max_visits,
147            root_policy_temperature: self.root_policy_temperature,
148            root_fpu_reduction_max: self.root_fpu_reduction_max,
149            analysis_pv_len: self.analysis_pv_len,
150            include_ownership: self.include_ownership,
151            include_ownership_stdev: self.include_ownership_stdev,
152            include_moves_ownership: self.include_moves_ownership,
153            include_moves_ownership_stdev: self.include_moves_ownership_stdev,
154            include_policy: self.include_policy,
155            include_pv_visits: self.include_pv_visits,
156            include_no_result_value: self.include_no_result_value,
157            avoid_moves: self.avoid_moves.map(|m| {
158                m.into_iter()
159                    .map(|rm| rm.into_engine_restricted_moves(self.board_y_size))
160                    .collect()
161            }),
162            allow_moves: self.allow_moves.map(|m| {
163                m.into_iter()
164                    .map(|rm| rm.into_engine_restricted_moves(self.board_y_size))
165                    .collect()
166            }),
167            override_settings: self.override_settings,
168            report_during_search_every: self.report_during_search_every,
169            priority: self.priority,
170            priorities,
171        }
172    }
173
174    /// Sets komi.
175    pub fn with_komi(mut self, komi: f64) -> Self {
176        self.komi = Some(komi);
177        self
178    }
179
180    /// Sets white's handicap bonus.
181    pub fn with_white_handicap_bonus(mut self, bonus: Bonus) -> Self {
182        self.white_handicap_bonus = Some(bonus);
183        self
184    }
185
186    /// Sets the initial position before the first move.
187    pub fn with_initial_stones(mut self, initial_stones: Vec<(Player, Coord)>) -> Self {
188        self.initial_stones = Some(initial_stones);
189        self
190    }
191
192    /// Sets the player to move in the initial position.
193    pub fn with_initial_player(mut self, initial_player: Player) -> Self {
194        self.initial_player = Some(initial_player);
195        self
196    }
197
198    /// Sets the maximum number of visits to use.
199    pub fn with_max_visits(mut self, max_visits: u32) -> Self {
200        self.max_visits = Some(max_visits);
201        self
202    }
203
204    /// Sets the root policy temperature.
205    pub fn with_root_policy_temperature(mut self, root_policy_temperature: f64) -> Self {
206        self.root_policy_temperature = Some(root_policy_temperature);
207        self
208    }
209
210    /// Sets the root FPU reduction max.
211    pub fn with_root_fpu_reduction_max(mut self, root_fpu_reduction_max: f64) -> Self {
212        self.root_fpu_reduction_max = Some(root_fpu_reduction_max);
213        self
214    }
215
216    /// Sets the maximum length of the principal variation to return, not including the first move.
217    pub fn with_analysis_pv_len(mut self, analysis_pv_len: usize) -> Self {
218        self.analysis_pv_len = Some(analysis_pv_len);
219        self
220    }
221
222    /// Includes the ownership prediction.
223    pub fn with_ownership(mut self) -> Self {
224        self.include_ownership = true;
225        self
226    }
227
228    /// Includes the standard deviation of the ownership prediction.
229    pub fn with_ownership_stdev(mut self) -> Self {
230        self.include_ownership_stdev = true;
231        self
232    }
233
234    /// Includes the ownership prediction for each move.
235    pub fn with_moves_ownership(mut self) -> Self {
236        self.include_moves_ownership = true;
237        self
238    }
239
240    /// Includes the standard deviation of the ownership prediction for each move.
241    pub fn with_moves_ownership_stdev(mut self) -> Self {
242        self.include_moves_ownership_stdev = true;
243        self
244    }
245
246    /// Includes the neural network policy output.
247    pub fn with_policy(mut self) -> Self {
248        self.include_policy = true;
249        self
250    }
251
252    /// Includes the number of visits for each position in the principal variation.
253    pub fn with_pv_visits(mut self) -> Self {
254        self.include_pv_visits = true;
255        self
256    }
257
258    /// Includes the predicted probability that the game will have a void result.
259    pub fn with_no_result_value(mut self) -> Self {
260        self.include_no_result_value = true;
261        self
262    }
263
264    /// Sets moves which are forbidden.
265    pub fn with_avoid_moves(mut self, avoid_moves: Vec<RestrictedMoves>) -> Self {
266        self.avoid_moves = Some(avoid_moves);
267        self
268    }
269
270    /// Sets moves which are allowed.
271    pub fn with_allow_moves(mut self, allow_moves: Vec<RestrictedMoves>) -> Self {
272        self.allow_moves = Some(allow_moves);
273        self
274    }
275
276    /// Overrides config settings for this request.
277    pub fn with_override_settings(mut self, config: Config) -> Self {
278        self.override_settings = Some(config);
279        self
280    }
281
282    /// Gets partial analysis results every this many seconds.
283    pub fn with_report_during_search_every(mut self, seconds: f64) -> Self {
284        self.report_during_search_every = Some(seconds);
285        self
286    }
287
288    /// Sets the priority of this request.
289    pub fn with_priority(mut self, priority: i32) -> Self {
290        self.priority = Some(priority);
291        self
292    }
293}
294
#[cfg(feature = "sgf-parse")]
impl From<&SgfNode<Prop>> for AnalysisRequest {
    /// Builds an analysis request from the root [`SgfNode`] of a game tree.
    ///
    /// These fields are filled in from the SGF data:
    /// - [`rules`](AnalysisRequest::rules).
    /// - [`komi`](AnalysisRequest::komi), when `KM` is present.
    /// - [`board_x_size`](AnalysisRequest::board_x_size) and
    ///   [`board_y_size`](AnalysisRequest::board_y_size), from `SZ`, defaulting to 19x19.
    /// - [`initial_stones`](AnalysisRequest::initial_stones), from the root's `AB`/`AW`. Left as
    ///   `None` when there are none.
    /// - [`initial_player`](AnalysisRequest::initial_player), when `PL` is present.
    /// - [`moves`](AnalysisRequest::moves), the `B`/`W` moves along the main variation.
    ///
    /// Everything else keeps the defaults from [`AnalysisRequest::new`].
    ///
    /// Rules are chosen by the first case that matches:
    /// - `RU` is present: its text becomes a [named ruleset](Rules::Named).
    /// - `KM` is present and greater than 6.5: [Chinese rules](Rules::chinese).
    /// - Otherwise: [Japanese rules](Rules::japanese).
    fn from(root: &SgfNode<Prop>) -> Self {
        let (width, height) = if let Some(Prop::SZ((w, h))) = root.get_property("SZ") {
            (*w, *h)
        } else {
            (19, 19)
        };

        let komi = if let Some(Prop::KM(value)) = root.get_property("KM") {
            Some(*value)
        } else {
            None
        };

        let rules = if let Some(Prop::RU(ruleset)) = root.get_property("RU") {
            Rules::Named(ruleset.text.clone())
        } else if matches!(komi, Some(value) if value > 6.5) {
            Rules::chinese()
        } else {
            Rules::japanese()
        };

        // Only B/W move properties count; nodes without a move (e.g. the root) are skipped.
        let moves: Vec<(Player, Move)> = root
            .main_variation()
            .filter_map(|node| {
                let (player, mv) = match node.get_move()? {
                    Prop::B(mv) => (Player::Black, mv),
                    Prop::W(mv) => (Player::White, mv),
                    _ => return None,
                };
                Some((player, (*mv).into()))
            })
            .collect();

        // Setup stones are read from the root only: black (AB) first, then white (AW).
        let mut initial_stones: Vec<(Player, Coord)> = Vec::new();
        if let Some(Prop::AB(points)) = root.get_property("AB") {
            for point in points.iter() {
                initial_stones.push((Player::Black, (*point).into()));
            }
        }
        if let Some(Prop::AW(points)) = root.get_property("AW") {
            for point in points.iter() {
                initial_stones.push((Player::White, (*point).into()));
            }
        }

        let initial_player: Option<Player> =
            root.get_property("PL").and_then(|prop| match prop {
                Prop::PL(color) => Some((*color).into()),
                _ => None,
            });

        Self {
            komi,
            initial_stones: if initial_stones.is_empty() {
                None
            } else {
                Some(initial_stones)
            },
            initial_player,
            ..Self::new(rules, width, height, moves)
        }
    }
}
359
/// A list of moves that are either forbidden with [`AnalysisRequest::avoid_moves`] or allowed with
/// [`AnalysisRequest::allow_moves`].
#[derive(Debug, Clone)]
pub struct RestrictedMoves {
    /// The player the move restriction applies to.
    pub player: Player,

    /// The list of moves. They are converted to GTP notation by
    /// [`RestrictedMoves::into_engine_restricted_moves`].
    pub moves: Vec<Move>,

    /// The search depth within which the restriction applies.
    ///
    /// NOTE(review): this is forwarded unchanged to the engine. It presumably maps to KataGo's
    /// `untilDepth`, where 1 restricts only the next move — confirm against the engine docs.
    pub until_depth: u32,
}
373
374impl RestrictedMoves {
375    /// Converts this restriction into the lower-level equivalent used by the [`engine`] module.
376    ///
377    /// You probably don't need to use this unless you're directly using the lower-level API in the [`engine`] module.
378    pub fn into_engine_restricted_moves(self, height: u8) -> engine::RestrictedMoves {
379        engine::RestrictedMoves {
380            player: self.player,
381            moves: self.moves.into_iter().map(|m| m.to_gtp(height)).collect(),
382            until_depth: self.until_depth,
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    #[cfg(feature = "sgf-parse")]
    mod sgf {
        use std::collections::HashSet;

        use crate::{AnalysisRequest, Coord, Move, Player, Rules};

        /// Parses `sgf` and builds a request from the root node of its first game tree.
        fn request_from(sgf: &str) -> AnalysisRequest {
            AnalysisRequest::from(sgf_parse::go::parse(sgf).unwrap().first().unwrap())
        }

        #[test]
        fn from_sgf() {
            let request = request_from("(;;B[pd](;B[dp];W[])(;W[dp];B[pp]))");
            assert_eq!(request.rules, Rules::japanese());
            assert_eq!(request.komi, None);
            assert_eq!(request.board_x_size, 19);
            assert_eq!(request.board_y_size, 19);
            assert_eq!(request.initial_stones, None);
            assert_eq!(request.initial_player, None);
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(15, 3))),
                    (Player::Black, Move::Move(Coord(3, 15))),
                    (Player::White, Move::Pass),
                ]
            );
        }

        #[test]
        fn size() {
            let request = request_from("(;SZ[9:13];B[aa];W[im])");
            assert_eq!(request.rules, Rules::japanese());
            assert_eq!(request.komi, None);
            assert_eq!(request.board_x_size, 9);
            assert_eq!(request.board_y_size, 13);
            assert_eq!(request.initial_stones, None);
            assert_eq!(request.initial_player, None);
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(0, 0))),
                    (Player::White, Move::Move(Coord(8, 12))),
                ]
            );
        }

        #[test]
        fn komi() {
            let request = request_from("(;KM[7.5];B[pd];W[dp])");
            assert_eq!(request.rules, Rules::chinese());
            assert_eq!(request.komi, Some(7.5));
            assert_eq!(request.board_x_size, 19);
            assert_eq!(request.board_y_size, 19);
            assert_eq!(request.initial_stones, None);
            assert_eq!(request.initial_player, None);
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(15, 3))),
                    (Player::White, Move::Move(Coord(3, 15))),
                ]
            );
        }

        /// Komi of exactly 6.5 is not "greater than 6.5", so Japanese rules are inferred.
        #[test]
        fn komi_threshold() {
            let request = request_from("(;KM[6.5];B[pd])");
            assert_eq!(request.rules, Rules::japanese());
            assert_eq!(request.komi, Some(6.5));
        }

        #[test]
        fn rules() {
            let request = request_from("(;RU[aga];B[pd];W[dp])");
            assert_eq!(request.rules, Rules::Named("aga".to_string()));
            assert_eq!(request.komi, None);
            assert_eq!(request.board_x_size, 19);
            assert_eq!(request.board_y_size, 19);
            assert_eq!(request.initial_stones, None);
            assert_eq!(request.initial_player, None);
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(15, 3))),
                    (Player::White, Move::Move(Coord(3, 15))),
                ]
            );
        }

        /// An explicit `RU` wins over the komi-based rules heuristic.
        #[test]
        fn rules_override_komi() {
            let request = request_from("(;RU[aga]KM[7.5];B[pd])");
            assert_eq!(request.rules, Rules::Named("aga".to_string()));
            assert_eq!(request.komi, Some(7.5));
        }

        #[test]
        fn initial_stones() {
            let request = request_from("(;AB[pd][dp]AW[dd][pp];B[cc];W[qc])");
            assert_eq!(request.rules, Rules::japanese());
            assert_eq!(request.komi, None);
            assert_eq!(request.board_x_size, 19);
            assert_eq!(request.board_y_size, 19);
            assert_eq!(
                request
                    .initial_stones
                    .map(|s| HashSet::<(Player, Coord)>::from_iter(s)),
                Some(HashSet::from_iter(vec![
                    (Player::Black, Coord(15, 3)),
                    (Player::Black, Coord(3, 15)),
                    (Player::White, Coord(3, 3)),
                    (Player::White, Coord(15, 15)),
                ]))
            );
            assert_eq!(request.initial_player, None);
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(2, 2))),
                    (Player::White, Move::Move(Coord(16, 2))),
                ]
            );
        }

        #[test]
        fn initial_player() {
            let request = request_from("(;PL[W];B[pd];W[dp])");
            assert_eq!(request.rules, Rules::japanese());
            assert_eq!(request.komi, None);
            assert_eq!(request.board_x_size, 19);
            assert_eq!(request.board_y_size, 19);
            assert_eq!(request.initial_stones, None);
            assert_eq!(request.initial_player, Some(Player::White));
            assert_eq!(
                request.moves,
                vec![
                    (Player::Black, Move::Move(Coord(15, 3))),
                    (Player::White, Move::Move(Coord(3, 15))),
                ]
            );
        }
    }
}