Skip to main content

katago_analysis/engine/
request.rs

1use std::ops::Not;
2
3use serde::Serialize;
4use serde_json::{Value, json};
5use serde_with::skip_serializing_none;
6
7use crate::{Bonus, Config, Player, Rules};
8
9/// A request to the analysis engine.
10#[derive(Debug, Clone, Serialize)]
11#[serde(into = "Value")]
12pub enum Request {
13    /// Request the engine to analyze one or more positions.
14    Analyze(AnalysisRequest),
15
16    /// Request KataGo's version information.
17    QueryVersion {
18        /// The request ID.
19        id: String,
20    },
21
22    /// Clear the neural network cache.
23    ClearCache {
24        /// The request ID.
25        id: String,
26    },
27
28    /// Terminate a specific analysis request.
29    Terminate {
30        /// The request ID.
31        id: String,
32
33        /// The ID of the request to terminate.
34        terminate_id: String,
35
36        /// If provided, only terminate the analysis for the specified turn numbers.
37        turn_numbers: Option<Vec<usize>>,
38    },
39
40    /// Terminate all pending analysis requests.
41    TerminateAll {
42        /// The request ID.
43        id: String,
44
45        /// If provided, only terminate the analysis for the specified turn numbers.
46        turn_numbers: Option<Vec<usize>>,
47    },
48
49    /// Request information about the available neural network models.
50    QueryModels {
51        /// The request ID.
52        id: String,
53    },
54}
55
56impl From<Request> for Value {
57    fn from(request: Request) -> Self {
58        match request {
59            Request::Analyze(request) => {
60                serde_json::to_value(request).expect("request should be serializable")
61            }
62            Request::QueryVersion { id } => json!({
63                "id": id,
64                "action": "query_version",
65            }),
66            Request::ClearCache { id } => json!({
67                "id": id,
68                "action": "clear_cache",
69            }),
70            Request::Terminate {
71                id,
72                terminate_id,
73                turn_numbers,
74            } => {
75                let mut value = json!({
76                        "id": id,
77                        "action": "terminate",
78                        "terminateId": terminate_id,
79                    }
80                );
81                if let Some(turn_numbers) = turn_numbers {
82                    value
83                        .as_object_mut()
84                        .expect("value should be an object")
85                        .insert("turnNumbers".to_string(), json!(turn_numbers));
86                }
87                value
88            }
89            Request::TerminateAll { id, turn_numbers } => {
90                let mut value = json!({
91                        "id": id,
92                        "action": "terminate_all",
93                    }
94                );
95                if let Some(turn_numbers) = turn_numbers {
96                    value
97                        .as_object_mut()
98                        .expect("value should be an object")
99                        .insert("turnNumbers".to_string(), json!(turn_numbers));
100                }
101                value
102            }
103            Request::QueryModels { id } => json!({
104                "id": id,
105                "action": "query_models",
106            }),
107        }
108    }
109}
110
111/// A game record to be analyzed, along with analysis settings.
112#[skip_serializing_none]
113#[derive(Debug, Clone, Serialize)]
114#[serde(rename_all = "camelCase")]
115pub struct AnalysisRequest {
116    /// The request ID.
117    pub id: String,
118
119    /// The ruleset for this game.
120    pub rules: Rules,
121
122    /// The komi for this game.
123    pub komi: Option<f64>,
124
125    /// Bonus points white receives in handicap games.
126    pub white_handicap_bonus: Option<Bonus>,
127
128    /// The board width.
129    pub board_x_size: u8,
130
131    /// The board height.
132    pub board_y_size: u8,
133
134    /// The stones on the board before the first move.
135    pub initial_stones: Option<Vec<(Player, String)>>,
136
137    /// The player to move in the initial position.
138    pub initial_player: Option<Player>,
139
140    /// The moves played in the game. Move locations can be in GTP format (`"A1"`, `"pass"`, etc.) or explicit
141    /// coordinates (`"(0,0)"`).
142    pub moves: Vec<(Player, String)>,
143
144    /// The positions to analyze, where 0 is the position before the first move.
145    /// If not provided, only the final position will be analyzed.
146    /// The engine will return a separate response for each position.
147    pub analyze_turns: Option<Vec<usize>>,
148
149    /// The maximum number of visits to use.
150    pub max_visits: Option<u32>,
151
152    /// Root policy temperature.
153    pub root_policy_temperature: Option<f64>,
154
155    /// Root FPU reduction max.
156    pub root_fpu_reduction_max: Option<f64>,
157
158    /// The maximum length of the principal variation to return, not including the first move.
159    #[serde(rename = "analysisPVLen")]
160    pub analysis_pv_len: Option<usize>,
161
162    /// Whether to return the ownership prediction.
163    #[serde(skip_serializing_if = "Not::not")]
164    pub include_ownership: bool,
165
166    /// Whether to return the standard deviation of the ownership prediction.
167    #[serde(skip_serializing_if = "Not::not")]
168    pub include_ownership_stdev: bool,
169
170    /// Whether to return the ownership prediction for each move.
171    #[serde(skip_serializing_if = "Not::not")]
172    pub include_moves_ownership: bool,
173
174    /// Whether to return the standard deviation of the ownership prediction for each move.
175    #[serde(skip_serializing_if = "Not::not")]
176    pub include_moves_ownership_stdev: bool,
177
178    /// Whether to return the neural network policy output.
179    #[serde(skip_serializing_if = "Not::not")]
180    pub include_policy: bool,
181
182    /// Whether to return the number of visits for each position in the principal variation.
183    #[serde(rename = "includePVVisits", skip_serializing_if = "Not::not")]
184    pub include_pv_visits: bool,
185
186    /// Whether to return the predicted probability that the game will have a void result.
187    #[serde(skip_serializing_if = "Not::not")]
188    pub include_no_result_value: bool,
189
190    /// Moves which are forbidden.
191    pub avoid_moves: Option<Vec<RestrictedMoves>>,
192
193    /// Moves which are allowed. If specified, all other moves are forbidden.
194    pub allow_moves: Option<Vec<RestrictedMoves>>,
195
196    /// Config overrides for this request.
197    pub override_settings: Option<Config>,
198
199    /// Report partial analysis results every this many seconds.
200    pub report_during_search_every: Option<f64>,
201
202    /// The priority of this request.
203    pub priority: Option<i32>,
204
205    /// The priorities of each position to analyze.
206    pub priorities: Option<Vec<i32>>,
207}
208
209impl AnalysisRequest {
210    /// Creates a new analysis request with the minimum required parameters.
211    pub fn new(
212        id: String,
213        rules: Rules,
214        board_x_size: u8,
215        board_y_size: u8,
216        moves: Vec<(Player, String)>,
217    ) -> Self {
218        Self {
219            id,
220            rules,
221            komi: None,
222            white_handicap_bonus: None,
223            board_x_size,
224            board_y_size,
225            initial_stones: None,
226            initial_player: None,
227            moves,
228            analyze_turns: None,
229            max_visits: None,
230            root_policy_temperature: None,
231            root_fpu_reduction_max: None,
232            analysis_pv_len: None,
233            include_ownership: false,
234            include_ownership_stdev: false,
235            include_moves_ownership: false,
236            include_moves_ownership_stdev: false,
237            include_policy: false,
238            include_pv_visits: false,
239            include_no_result_value: false,
240            avoid_moves: None,
241            allow_moves: None,
242            override_settings: None,
243            report_during_search_every: None,
244            priority: None,
245            priorities: None,
246        }
247    }
248
249    /// Sets komi.
250    pub fn with_komi(mut self, komi: f64) -> Self {
251        self.komi = Some(komi);
252        self
253    }
254
255    /// Sets white's handicap bonus.
256    pub fn with_white_handicap_bonus(mut self, bonus: Bonus) -> Self {
257        self.white_handicap_bonus = Some(bonus);
258        self
259    }
260
261    /// Sets the initial position before the first move.
262    pub fn with_initial_stones(mut self, initial_stones: Vec<(Player, String)>) -> Self {
263        self.initial_stones = Some(initial_stones);
264        self
265    }
266
267    /// Sets the player to move in the initial position.
268    pub fn with_initial_player(mut self, initial_player: Player) -> Self {
269        self.initial_player = Some(initial_player);
270        self
271    }
272
273    /// Analyzes the specified positions. The position before the first move is turn 0.
274    pub fn with_analyze_turns(mut self, analyze_turns: Vec<usize>) -> Self {
275        self.analyze_turns = Some(analyze_turns);
276        self
277    }
278
279    /// Sets the maximum number of visits to use.
280    pub fn with_max_visits(mut self, max_visits: u32) -> Self {
281        self.max_visits = Some(max_visits);
282        self
283    }
284
285    /// Sets the root policy temperature.
286    pub fn with_root_policy_temperature(mut self, root_policy_temperature: f64) -> Self {
287        self.root_policy_temperature = Some(root_policy_temperature);
288        self
289    }
290
291    /// Sets the root FPU reduction max.
292    pub fn with_root_fpu_reduction_max(mut self, root_fpu_reduction_max: f64) -> Self {
293        self.root_fpu_reduction_max = Some(root_fpu_reduction_max);
294        self
295    }
296
297    /// Sets the maximum length of the principal variation to return, not including the first move.
298    pub fn with_analysis_pv_len(mut self, analysis_pv_len: usize) -> Self {
299        self.analysis_pv_len = Some(analysis_pv_len);
300        self
301    }
302
303    /// Includes the ownership prediction.
304    pub fn with_ownership(mut self) -> Self {
305        self.include_ownership = true;
306        self
307    }
308
309    /// Includes the standard deviation of the ownership prediction.
310    pub fn with_ownership_stdev(mut self) -> Self {
311        self.include_ownership_stdev = true;
312        self
313    }
314
315    /// Includes the ownership prediction for each move.
316    pub fn with_moves_ownership(mut self) -> Self {
317        self.include_moves_ownership = true;
318        self
319    }
320
321    /// Includes the standard deviation of the ownership prediction for each move.
322    pub fn with_moves_ownership_stdev(mut self) -> Self {
323        self.include_moves_ownership_stdev = true;
324        self
325    }
326
327    /// Includes the neural network policy output.
328    pub fn with_policy(mut self) -> Self {
329        self.include_policy = true;
330        self
331    }
332
333    /// Includes the number of visits for each position in the principal variation.
334    pub fn with_pv_visits(mut self) -> Self {
335        self.include_pv_visits = true;
336        self
337    }
338
339    /// Includes the predicted probability that the game will have a void result.
340    pub fn with_no_result_value(mut self) -> Self {
341        self.include_no_result_value = true;
342        self
343    }
344
345    /// Sets moves which are forbidden.
346    pub fn with_avoid_moves(mut self, avoid_moves: Vec<RestrictedMoves>) -> Self {
347        self.avoid_moves = Some(avoid_moves);
348        self
349    }
350
351    /// Sets moves which are allowed.
352    pub fn with_allow_moves(mut self, allow_moves: Vec<RestrictedMoves>) -> Self {
353        self.allow_moves = Some(allow_moves);
354        self
355    }
356
357    /// Overrides config settings for this request.
358    pub fn with_override_settings(mut self, config: Config) -> Self {
359        self.override_settings = Some(config);
360        self
361    }
362
363    /// Gets partial analysis results every this many seconds.
364    pub fn with_report_during_search_every(mut self, seconds: f64) -> Self {
365        self.report_during_search_every = Some(seconds);
366        self
367    }
368
369    /// Sets the priority of this request.
370    pub fn with_priority(mut self, priority: i32) -> Self {
371        self.priority = Some(priority);
372        self
373    }
374
375    /// Sets the priorities of each position to analyze.
376    pub fn with_priorities(mut self, priorities: Vec<i32>) -> Self {
377        self.priorities = Some(priorities);
378        self
379    }
380}
381
382/// A list of moves that are either forbidden with [`AnalysisRequest::avoid_moves`] or allowed with
383/// [`AnalysisRequest::allow_moves`].
384#[derive(Debug, Clone, Serialize)]
385#[serde(rename_all = "camelCase")]
386pub struct RestrictedMoves {
387    /// The player the move restriction applies to.
388    pub player: Player,
389
390    /// The list of moves.
391    pub moves: Vec<String>,
392
393    /// The search depth within which the restriction applies.
394    pub until_depth: u32,
395}