use crate::action::{Action, ActionType};
use crate::parser::mjai_to_tid;
use super::Observation;
/// Size of the sparse-token vocabulary. `Observation::encode_seq_sparse` emits ids
/// `0..=440`, and `SPARSE_PAD` (441) is the last id.
#[allow(dead_code)]
pub const SPARSE_VOCAB_SIZE: usize = 442;
/// Padding id for sparse token sequences (`SPARSE_VOCAB_SIZE - 1`).
#[allow(dead_code)]
pub const SPARSE_PAD: u16 = 441;
/// Capacity reserved for a sparse token sequence by `encode_seq_sparse`.
/// NOTE(review): presumably also the longest sequence consumers accept; the length
/// is not enforced anywhere in this file — confirm against the consumer.
pub const MAX_SPARSE_LEN: usize = 25;
/// Per-column sizes of a progression row `[actor, type, moqie, liqi, from]`.
/// NOTE(review): presumably embedding cardinalities on the model side — confirm.
#[allow(dead_code)]
pub const PROG_DIMS: [u16; 5] = [5, 277, 3, 3, 5];
/// Maximum number of progression rows produced by `encode_seq_progression`.
pub const MAX_PROG_LEN: usize = 512;
/// Padding value for each progression column (each entry is `PROG_DIMS[i] - 1`).
/// The encoders in this file use the same values (4 / 2 / 2 / 4) for columns that
/// do not apply to an event.
#[allow(dead_code)]
pub const PROG_PAD: [u16; 5] = [4, 276, 2, 2, 4];
/// Per-column sizes of a candidate row `[type, moqie, <third column>, from]`.
/// NOTE(review): the third column is always written as 2 in this file; presumably a
/// riichi flag by analogy with the progression rows — confirm.
#[allow(dead_code)]
pub const CAND_DIMS: [u16; 4] = [280, 3, 3, 4];
/// Matches the capacity (64) reserved by `encode_seq_candidates`.
/// NOTE(review): presumably the maximum candidate count; not enforced in this file.
#[allow(dead_code)]
pub const MAX_CAND_LEN: usize = 64;
/// Padding value for each candidate column (each entry is `CAND_DIMS[i] - 1`).
#[allow(dead_code)]
pub const CAND_PAD: [u16; 4] = [279, 2, 2, 3];
/// Length of the array returned by `Observation::encode_seq_numeric`.
pub const NUM_NUMERIC: usize = 12;
/// Maps a tile id (four copies per tile type, ids `0..136`) to its 37-class index.
///
/// Each numbered suit occupies ten consecutive slots with the red five first:
/// `0` / `10` / `20` are the red fives (tile ids 16, 52 and 88), `1..=9`,
/// `11..=19` and `21..=29` are ranks 1-9 of the three suits, and `30..=36` are the
/// seven honor tiles.
///
/// Any id that does not name a real tile (`>= 136`) maps to `0`. The conversion to
/// `u8` is checked, so very large ids can no longer wrap around and alias a valid
/// tile type (previously `1024` was reported as slot `1`).
pub fn tile_id_to_kan37(tile_id: u32) -> u8 {
    match tile_id {
        16 => 0,
        52 => 10,
        88 => 20,
        // A tile type that does not even fit in a u8 is certainly out of range;
        // give it the same fallback as every other invalid type.
        _ => u8::try_from(tile_id / 4).map_or(0, tile_type_to_kan37),
    }
}

/// Maps a 34-class tile type to the 37-class index of its non-red tile.
///
/// Each suit is shifted by one extra slot to leave room for its red five.
/// Unknown tile types (`>= 34`) map to `0`.
fn tile_type_to_kan37(tile_type: u8) -> u8 {
    match tile_type {
        // First suit: slots 1..=9 (slot 0 is its red five).
        0..=8 => tile_type + 1,
        // Second suit: slots 11..=19 (slot 10 is its red five).
        9..=17 => tile_type + 2,
        // Third suit: slots 21..=29 (slot 20 is its red five).
        18..=26 => tile_type + 3,
        // Honors follow directly: slots 30..=36.
        27..=33 => tile_type + 3,
        _ => 0,
    }
}
/// Converts an mjai tile string to its 37-class index.
///
/// Returns `None` when `mjai_to_tid` does not recognise the string.
fn mjai_tile_to_kan37(mjai: &str) -> Option<u8> {
    mjai_to_tid(mjai).map(|tid| tile_id_to_kan37(tid as u32))
}
/// Encodes a chi (sequence call) as a compact index.
///
/// `consumed` holds the tile ids taken from the caller's hand and `called_tile` is
/// the claimed tile id. The result is `suit * 30 + start_offset + variant`, where:
///
/// * `start_offset` accumulates the slots used by sequences starting at a lower
///   rank: 6 for a sequence that contains a five, 3 otherwise (30 per suit);
/// * `variant` is the position of the called tile inside the sequence (0..=2),
///   plus 3 when the sequence contains a five and any of the tiles is a red five
///   (tile ids 16, 52, 88).
pub fn encode_chi(consumed: &[u8], called_tile: u8) -> u16 {
    let is_red = |tile: u8| matches!(tile, 16 | 52 | 88);
    // A sequence starting at rank index `start` covers start..=start + 2;
    // rank index 4 is the five.
    let spans_five = |start: u8| (start..=start + 2).contains(&4);

    // The lowest tile id of the meld determines where the sequence starts.
    let lowest = consumed.iter().copied().fold(called_tile, u8::min);
    let first_type = lowest / 4;
    let suit = first_type / 9;
    let suit_base = suit * 9;
    let seq_start = first_type - suit_base;
    let call_pos = called_tile / 4 - suit_base - seq_start;

    let has_red = is_red(called_tile) || consumed.iter().any(|&tile| is_red(tile));

    // Slots consumed by every sequence that starts lower in the same suit.
    let start_offset: u16 = (0..seq_start)
        .map(|start| if spans_five(start) { 6 } else { 3 })
        .sum();

    let variant = if spans_five(seq_start) && has_red {
        3 + call_pos
    } else {
        call_pos
    };

    u16::from(suit) * 30 + start_offset + u16::from(variant)
}
/// Encodes a pon (triplet call) as a compact index.
///
/// Numbered suits use 11 slots each: ranks below the five keep their rank index
/// (0..=3), the five takes three slots (4 = no red five involved, 5 = a red five
/// among `consumed`, 6 = the called tile itself is the red five), and ranks above
/// the five are shifted by two (7..=10). Honor tiles follow at `33..=39`.
pub fn encode_pon(consumed: &[u8], called_tile: u8) -> u16 {
    let is_red = |tile: u8| matches!(tile, 16 | 52 | 88);

    let tile_type = called_tile / 4;
    let suit = tile_type / 9;
    if suit == 3 {
        // Honors: one slot per tile type after the three numbered suits.
        return 33 + u16::from(tile_type - 27);
    }

    let rank = u16::from(tile_type - suit * 9);
    let slot_in_suit = match rank {
        0..=3 => rank,
        4 => {
            let red_variant = if is_red(called_tile) {
                2
            } else if consumed.iter().any(|&tile| is_red(tile)) {
                1
            } else {
                0
            };
            4 + red_variant
        }
        // Ranks above the five skip the two extra red-five slots.
        _ => rank + 2,
    };

    u16::from(suit) * 11 + slot_in_suit
}
/// Returns the seat of `target` relative to `actor` as a value in `0..=3`.
///
/// The result is `(target - actor + 3) mod 4`: `0` is the seat right after the
/// actor, `1` the seat two positions on, `2` the seat right before the actor, and
/// `3` the actor itself.
///
/// The subtraction is done in `i16` with a Euclidean remainder, so seat numbers
/// outside `0..=3` (they originate from untrusted JSON) can neither overflow nor
/// yield a value outside `0..=3`.
fn relative_from(actor: u8, target: u8) -> u8 {
    let offset = i16::from(target) - i16::from(actor) + 3;
    // rem_euclid(4) is always in 0..=3, so the narrowing cast is lossless.
    offset.rem_euclid(4) as u8
}
/// Collects the tile ids listed in the event's `"consumed"` array.
///
/// Entries that are not strings, or that `mjai_to_tid` does not recognise, are
/// skipped. A missing or non-array `"consumed"` field yields an empty vector.
fn parse_consumed_tids_from_value(v: &serde_json::Value) -> Vec<u8> {
    v["consumed"]
        .as_array()
        .into_iter()
        .flatten()
        .filter_map(|item| item.as_str())
        .filter_map(|name| mjai_to_tid(name))
        .collect()
}
/// Converts one parsed event into a progression row `[actor, type, moqie, liqi, from]`.
///
/// Returns `None` for events that produce no row: unknown event types, `reach`
/// (which only records the declaring seat in `pending_reach_actor`), and events
/// whose tile is hidden (`"?"`), missing or unparseable.
///
/// Type index ranges: `0` start of round, `1..=37` discard, `38..=127` chi,
/// `128..=167` pon, `168..=204` open kan, `205..=238` closed kan, `239..=275`
/// added kan. Columns that do not apply to an event carry `2` (moqie / liqi) or
/// `4` (actor / from).
///
/// A discard by the seat stored in `pending_reach_actor` is flagged with `liqi = 1`
/// and clears the pending state.
pub fn process_single_event_progression(
    event: &serde_json::Value,
    pending_reach_actor: &mut Option<u8>,
) -> Option<[u16; 5]> {
    /// Reads a seat field, falling back to seat 0 when it is absent.
    fn seat_of(event: &serde_json::Value, key: &str) -> u8 {
        event[key].as_u64().unwrap_or(0) as u8
    }

    /// Returns the event's tile string unless it is missing or hidden (`"?"`).
    fn known_pai(event: &serde_json::Value) -> Option<&str> {
        event["pai"].as_str().filter(|pai| *pai != "?")
    }

    match event["type"].as_str()? {
        "start_kyoku" => Some([4, 0, 2, 2, 4]),
        "reach" => {
            if let Some(actor) = event["actor"].as_u64() {
                *pending_reach_actor = Some(actor as u8);
            }
            None
        }
        "dahai" => {
            let actor = seat_of(event, "actor");
            let k37 = mjai_tile_to_kan37(known_pai(event)?)?;
            let moqie = u16::from(event["tsumogiri"].as_bool().unwrap_or(false));
            // The first discard after a reach declaration by the same seat is
            // the riichi discard.
            let liqi = if *pending_reach_actor == Some(actor) {
                *pending_reach_actor = None;
                1
            } else {
                0
            };
            Some([actor as u16, 1 + k37 as u16, moqie, liqi, 4])
        }
        kind @ ("chi" | "pon") => {
            let actor = seat_of(event, "actor");
            let target = seat_of(event, "target");
            let called_tid = mjai_to_tid(known_pai(event)?)?;
            let consumed = parse_consumed_tids_from_value(event);
            if consumed.len() < 2 {
                return None;
            }
            let type_idx = if kind == "chi" {
                38 + encode_chi(&consumed, called_tid)
            } else {
                128 + encode_pon(&consumed, called_tid)
            };
            Some([
                actor as u16,
                type_idx,
                2,
                2,
                relative_from(actor, target) as u16,
            ])
        }
        "daiminkan" => {
            let actor = seat_of(event, "actor");
            let target = seat_of(event, "target");
            let k37 = mjai_tile_to_kan37(known_pai(event)?)?;
            Some([
                actor as u16,
                168 + k37 as u16,
                2,
                2,
                relative_from(actor, target) as u16,
            ])
        }
        "ankan" => {
            let actor = seat_of(event, "actor");
            let consumed = parse_consumed_tids_from_value(event);
            // Closed kans are keyed by the 34-class type of the first consumed tile.
            let tile34 = *consumed.first()? / 4;
            Some([actor as u16, 205 + tile34 as u16, 2, 2, 4])
        }
        "kakan" => {
            let actor = seat_of(event, "actor");
            let k37 = mjai_tile_to_kan37(known_pai(event)?)?;
            Some([actor as u16, 239 + k37 as u16, 2, 2, 4])
        }
        _ => None,
    }
}
impl Observation {
/// Builds the sparse token sequence describing the current state.
///
/// Token ranges (base offset plus value):
/// * `0..=1` game style, `2..=5` own seat, `6..=8` round wind, `9..=12` dealer seat;
/// * `13..=82` remaining wall tiles (capped at 69);
/// * `83..=267` up to five dora indicators, 37 slots per indicator position;
/// * `268..=403` one token per tile id in the observer's own hand;
/// * `404..=440` the tile just drawn, if any.
pub fn encode_seq_sparse(&self, game_style: u8) -> Vec<u16> {
    let mut tokens: Vec<u16> = Vec::with_capacity(MAX_SPARSE_LEN);

    // Fixed header; every value is clamped to its slot range.
    tokens.extend([
        game_style.min(1) as u16,
        2 + self.player_id.min(3) as u16,
        6 + self.round_wind.min(2) as u16,
        9 + self.oya.min(3) as u16,
        13 + self.count_tiles_remaining().min(69),
    ]);

    // At most five indicators, each in its own block of 37 slots.
    for (slot, &indicator) in self.dora_indicators.iter().take(5).enumerate() {
        let k37 = tile_id_to_kan37(indicator);
        tokens.push(83 + (slot as u16) * 37 + k37 as u16);
    }

    // Own hand, one token per physical tile; ids outside 0..136 are dropped.
    let hand = &self.hands[self.player_id as usize];
    tokens.extend(
        hand.iter()
            .map(|&tid| tid as u16)
            .filter(|&tid| tid < 136)
            .map(|tid| 268 + tid),
    );

    if let Some(k37) = self
        .get_drawn_tile()
        .map(|tid| tile_id_to_kan37(tid as u32))
    {
        tokens.push(404 + k37 as u16);
    }

    tokens
}
/// Estimates the number of tiles left in the wall.
///
/// Starts from 136 tiles, subtracts a fixed 14 and every tile counted in the four
/// hands, the four discard piles, all melds and the dora indicators, and saturates
/// at zero.
fn count_tiles_remaining(&self) -> u16 {
    const SEATS: usize = 4;
    const TOTAL_TILES: u32 = 136;
    const RESERVED: u32 = 14;

    let in_hands: u32 = (0..SEATS).map(|seat| self.hands[seat].len() as u32).sum();
    let in_discards: u32 = (0..SEATS)
        .map(|seat| self.discards[seat].len() as u32)
        .sum();
    let in_melds: u32 = (0..SEATS)
        .map(|seat| {
            self.melds[seat]
                .iter()
                .map(|meld| meld.tiles.len() as u32)
                .sum::<u32>()
        })
        .sum();
    let indicators = self.dora_indicators.len() as u32;

    let used = in_hands + in_discards + in_melds + indicators;
    TOTAL_TILES.saturating_sub(RESERVED + used) as u16
}
/// Finds the tile id the observer has just drawn, if any.
///
/// Walks the event log backwards. The first `tsumo` by the observer with a visible
/// tile decides the result (which is `None` if that tile string cannot be parsed).
/// The search stops with `None` as soon as a `dahai`, `chi`, `pon` or `daiminkan`
/// event is reached. Events that are not valid JSON are skipped.
fn get_drawn_tile(&self) -> Option<u8> {
    for raw in self.events.iter().rev() {
        let Ok(event) = serde_json::from_str::<serde_json::Value>(raw) else {
            continue;
        };
        match event["type"].as_str().unwrap_or("") {
            "tsumo" => {
                if event["actor"].as_u64() == Some(self.player_id as u64) {
                    if let Some(pai) = event["pai"].as_str() {
                        if pai != "?" {
                            return mjai_to_tid(pai);
                        }
                    }
                }
            }
            // Any of these means the last draw (if there was one) is already spent.
            "dahai" | "chi" | "pon" | "daiminkan" => break,
            _ => {}
        }
    }
    None
}
/// Builds the numeric feature vector.
///
/// Layout: `[0]` honba, `[1]` riichi sticks, `[2..6]` current scores,
/// `[6]` honba at round start, `[7]` riichi sticks at round start,
/// `[8..12]` scores at round start. Both score groups are rotated so that the
/// observer's own seat comes first.
pub fn encode_seq_numeric(&self) -> [f32; NUM_NUMERIC] {
    let me = self.player_id as usize;
    let (start_honba, start_riichi, start_scores) = self.parse_start_kyoku_info();

    let mut out = [0.0f32; NUM_NUMERIC];
    out[0] = self.honba as f32;
    out[1] = self.riichi_sticks as f32;
    out[6] = start_honba as f32;
    out[7] = start_riichi as f32;

    for offset in 0..4 {
        let seat = (me + offset) % 4;
        out[2 + offset] = self.scores[seat] as f32;
        out[8 + offset] = start_scores[seat] as f32;
    }
    out
}
fn parse_start_kyoku_info(&self) -> (u32, u32, [i32; 4]) {
for event_str in &self.events {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(event_str)
&& v["type"].as_str() == Some("start_kyoku")
{
let honba = v["honba"].as_u64().unwrap_or(0) as u32;
let kyotaku = v["kyotaku"].as_u64().unwrap_or(0) as u32;
let mut scores = [0i32; 4];
if let Some(arr) = v["scores"].as_array() {
for (i, val) in arr.iter().enumerate().take(4) {
scores[i] = val.as_i64().unwrap_or(0) as i32;
}
}
return (honba, kyotaku, scores);
}
}
(self.honba as u32, self.riichi_sticks, self.scores)
}
pub fn encode_seq_progression(&self) -> Vec<[u16; 5]> {
if let Some(ref cached) = self.cached_progression {
return cached.clone();
}
let mut prog: Vec<[u16; 5]> = Vec::with_capacity(128);
let mut pending_reach_actor: Option<u8> = None;
for event_str in &self.events {
let v = match serde_json::from_str::<serde_json::Value>(event_str) {
Ok(v) => v,
Err(_) => continue,
};
let event_type = match v["type"].as_str() {
Some(t) => t,
None => continue,
};
match event_type {
"start_kyoku" => {
prog.push([4, 0, 2, 2, 4]);
}
"reach" => {
if let Some(actor) = v["actor"].as_u64() {
pending_reach_actor = Some(actor as u8);
}
}
"dahai" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let pai = v["pai"].as_str().unwrap_or("?");
let tsumogiri = v["tsumogiri"].as_bool().unwrap_or(false);
if pai == "?" {
continue; }
let k37 = match mjai_tile_to_kan37(pai) {
Some(k) => k,
None => continue,
};
let type_idx = 1 + k37 as u16; let moqie = if tsumogiri { 1 } else { 0 };
let liqi = if pending_reach_actor == Some(actor) {
pending_reach_actor = None;
1
} else {
0
};
prog.push([actor as u16, type_idx, moqie, liqi, 4]);
}
"chi" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let target = v["target"].as_u64().unwrap_or(0) as u8;
let pai = v["pai"].as_str().unwrap_or("?");
if pai == "?" {
continue;
}
let called_tid = match mjai_to_tid(pai) {
Some(t) => t,
None => continue,
};
let consumed = self.parse_consumed_tids(&v);
if consumed.len() < 2 {
continue;
}
let chi_enc = encode_chi(&consumed, called_tid);
let type_idx = 38 + chi_enc; let rel = relative_from(actor, target);
prog.push([actor as u16, type_idx, 2, 2, rel as u16]);
}
"pon" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let target = v["target"].as_u64().unwrap_or(0) as u8;
let pai = v["pai"].as_str().unwrap_or("?");
if pai == "?" {
continue;
}
let called_tid = match mjai_to_tid(pai) {
Some(t) => t,
None => continue,
};
let consumed = self.parse_consumed_tids(&v);
if consumed.len() < 2 {
continue;
}
let pon_enc = encode_pon(&consumed, called_tid);
let type_idx = 128 + pon_enc; let rel = relative_from(actor, target);
prog.push([actor as u16, type_idx, 2, 2, rel as u16]);
}
"daiminkan" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let target = v["target"].as_u64().unwrap_or(0) as u8;
let pai = v["pai"].as_str().unwrap_or("?");
if pai == "?" {
continue;
}
let k37 = match mjai_tile_to_kan37(pai) {
Some(k) => k,
None => continue,
};
let type_idx = 168 + k37 as u16; let rel = relative_from(actor, target);
prog.push([actor as u16, type_idx, 2, 2, rel as u16]);
}
"ankan" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let consumed = self.parse_consumed_tids(&v);
if consumed.is_empty() {
continue;
}
let tile34 = consumed[0] / 4;
let type_idx = 205 + tile34 as u16;
prog.push([actor as u16, type_idx, 2, 2, 4]);
}
"kakan" => {
let actor = v["actor"].as_u64().unwrap_or(0) as u8;
let pai = v["pai"].as_str().unwrap_or("?");
if pai == "?" {
continue;
}
let k37 = match mjai_tile_to_kan37(pai) {
Some(k) => k,
None => continue,
};
let type_idx = 239 + k37 as u16;
prog.push([actor as u16, type_idx, 2, 2, 4]);
}
_ => {
}
}
if prog.len() >= MAX_PROG_LEN {
break;
}
}
prog
}
fn parse_consumed_tids(&self, v: &serde_json::Value) -> Vec<u8> {
let mut tids = Vec::new();
if let Some(arr) = v["consumed"].as_array() {
for item in arr {
if let Some(s) = item.as_str()
&& let Some(tid) = mjai_to_tid(s)
{
tids.push(tid);
}
}
}
tids
}
/// Encodes every legal action as a candidate row `[type, moqie, flag, from]`.
///
/// Actions for which `encode_candidate_action` returns `None` are left out; the
/// remaining rows keep the order of the legal-action list.
pub fn encode_seq_candidates(&self) -> Vec<[u16; 4]> {
    let seat = self.player_id;
    let riichi_available = self
        ._legal_actions
        .iter()
        .any(|action| action.action_type == ActionType::Riichi);

    let mut rows: Vec<[u16; 4]> = Vec::with_capacity(64);
    rows.extend(
        self._legal_actions
            .iter()
            .filter_map(|action| self.encode_candidate_action(action, seat, riichi_available)),
    );
    rows
}
/// Encodes one legal action as a candidate row `[type, moqie, flag, from]`.
///
/// Type index ranges: `0..=36` discard, `37..=70` closed kan, `71..=107` added kan,
/// `108` tsumo, `109` kyushu kyuhai, `110` pass, `111..=200` chi, `201..=240` pon,
/// `241..=277` open kan, `278` ron.
///
/// `moqie` is `0`/`1` for discards (1 when the tile equals the tile just drawn) and
/// `2` otherwise; the third column is always `2`. `from` is the seat of the most
/// recent discarder relative to `pid` for calls and ron, and `3` otherwise.
///
/// Returns `None` for riichi and kita actions, when required tile data is missing,
/// or when a call has no preceding discard in the event log. `_has_riichi` is
/// currently unused.
fn encode_candidate_action(
    &self,
    action: &Action,
    pid: u8,
    _has_riichi: bool,
) -> Option<[u16; 4]> {
    // Values written into columns that do not apply to an action.
    const NA_FLAG: u16 = 2;
    const NA_FROM: u16 = 3;

    // Relative seat of whoever made the most recent discard (or added kan).
    let from_discarder = || {
        self.find_last_discard_actor()
            .map(|target| relative_from(pid, target) as u16)
    };

    let (type_idx, moqie, from) = match action.action_type {
        ActionType::Discard => {
            let tile = action.tile?;
            let k37 = tile_id_to_kan37(tile as u32);
            let drawn_flag = u16::from(self.is_tsumogiri_candidate(tile));
            (k37 as u16, drawn_flag, NA_FROM)
        }
        ActionType::Ankan => {
            let first = *action.consume_tiles.first()?;
            (37 + (first / 4) as u16, NA_FLAG, NA_FROM)
        }
        ActionType::Kakan => {
            // Fall back to the first consumed tile when no tile is attached.
            let tile = action
                .tile
                .or_else(|| action.consume_tiles.first().copied())?;
            (71 + tile_id_to_kan37(tile as u32) as u16, NA_FLAG, NA_FROM)
        }
        ActionType::Tsumo => (108, NA_FLAG, NA_FROM),
        ActionType::KyushuKyuhai => (109, NA_FLAG, NA_FROM),
        ActionType::Pass => (110, NA_FLAG, NA_FROM),
        ActionType::Chi => {
            let called = action.tile?;
            if action.consume_tiles.len() < 2 {
                return None;
            }
            let type_idx = 111 + encode_chi(&action.consume_tiles, called);
            (type_idx, NA_FLAG, from_discarder()?)
        }
        ActionType::Pon => {
            let called = action.tile?;
            if action.consume_tiles.len() < 2 {
                return None;
            }
            let type_idx = 201 + encode_pon(&action.consume_tiles, called);
            (type_idx, NA_FLAG, from_discarder()?)
        }
        ActionType::Daiminkan => {
            let tile = action.tile?;
            let type_idx = 241 + tile_id_to_kan37(tile as u32) as u16;
            (type_idx, NA_FLAG, from_discarder()?)
        }
        ActionType::Ron => (278, NA_FLAG, from_discarder()?),
        ActionType::Riichi | ActionType::Kita => return None,
    };

    Some([type_idx, moqie, NA_FLAG, from])
}
/// Returns `true` when `tile` is exactly the tile id the observer has just drawn.
fn is_tsumogiri_candidate(&self, tile: u8) -> bool {
    self.get_drawn_tile() == Some(tile)
}
/// Returns the acting seat of the most recent `dahai` or `kakan` event.
///
/// Events that are not valid JSON are skipped. The result is `None` when no such
/// event exists or when the matching event carries no numeric `actor`.
fn find_last_discard_actor(&self) -> Option<u8> {
    self.events
        .iter()
        .rev()
        .filter_map(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
        .find(|event| matches!(event["type"].as_str().unwrap_or(""), "dahai" | "kakan"))
        .and_then(|event| event["actor"].as_u64())
        .map(|actor| actor as u8)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Red fives, suit boundaries and honors land in the expected 37-class slot.
    #[test]
    fn test_tile_id_to_kan37() {
        // (tile id, expected slot)
        let cases: [(u32, u8); 11] = [
            (16, 0),
            (52, 10),
            (88, 20),
            (0, 1),
            (3, 1),
            (17, 5),
            (32, 9),
            (36, 11),
            (72, 21),
            (108, 30),
            (132, 36),
        ];
        for (tile_id, expected) in cases {
            assert_eq!(tile_id_to_kan37(tile_id), expected, "tile id {tile_id}");
        }
    }

    #[test]
    fn test_relative_from() {
        assert_eq!(relative_from(0, 3), 2);
        assert_eq!(relative_from(0, 1), 0);
        assert_eq!(relative_from(0, 2), 1);
        assert_eq!(relative_from(2, 3), 0);
    }

    #[test]
    fn test_encode_chi_basic() {
        // Same sequence, called tile at the low end and in the middle.
        assert_eq!(encode_chi(&[4, 8], 0), 0);
        assert_eq!(encode_chi(&[0, 8], 4), 1);
    }

    #[test]
    fn test_encode_pon_honor() {
        assert_eq!(encode_pon(&[109, 110], 108), 33);
    }

    #[test]
    fn test_encode_pon_five_red() {
        // Red five taken from the hand.
        assert_eq!(encode_pon(&[16, 17], 18), 5);
        // Red five is the called tile.
        assert_eq!(encode_pon(&[17, 18], 16), 6);
    }

    /// Every sparse token family must stay below the padding id.
    #[test]
    fn test_sparse_vocab_bounds() {
        assert_eq!(usize::from(SPARSE_PAD) + 1, SPARSE_VOCAB_SIZE);

        let max_k37 = u16::from(tile_id_to_kan37(135));
        assert_eq!(max_k37, 36);
        let max_tile_id: u16 = 135;

        let family_tops = [
            83 + 4 * 37 + max_k37, // fifth dora indicator
            268 + max_tile_id,     // hand tile
            404 + max_k37,         // drawn tile
        ];
        for top in family_tops {
            assert!(top < SPARSE_PAD, "token {top} collides with padding");
        }
    }

    /// Every progression type index must stay below the padding slot of the
    /// type column, not merely within the raw dimension.
    #[test]
    fn test_progression_type_bounds() {
        let type_pad = PROG_PAD[1];
        assert_eq!(type_pad + 1, PROG_DIMS[1]);

        let max_k37 = u16::from(tile_id_to_kan37(135));
        let max_chi = encode_chi(&[96, 100], 104);
        let max_pon = encode_pon(&[133, 134], 132);
        let max_tile34: u16 = 135 / 4;
        assert_eq!(max_chi, 89);
        assert_eq!(max_pon, 39);

        let family_tops = [
            1 + max_k37,      // discard
            38 + max_chi,     // chi
            128 + max_pon,    // pon
            168 + max_k37,    // open kan
            205 + max_tile34, // closed kan
            239 + max_k37,    // added kan
        ];
        for top in family_tops {
            assert!(top < type_pad, "type index {top} collides with padding");
        }
    }

    /// Every candidate type index must stay below the padding slot of the
    /// type column.
    #[test]
    fn test_candidate_type_bounds() {
        let type_pad = CAND_PAD[0];
        assert_eq!(type_pad + 1, CAND_DIMS[0]);

        let max_k37 = u16::from(tile_id_to_kan37(135));
        let max_chi = encode_chi(&[96, 100], 104);
        let max_pon = encode_pon(&[133, 134], 132);
        let max_tile34: u16 = 135 / 4;

        let family_tops = [
            max_k37,         // discard
            37 + max_tile34, // closed kan
            71 + max_k37,    // added kan
            108,             // tsumo
            109,             // kyushu kyuhai
            110,             // pass
            111 + max_chi,   // chi
            201 + max_pon,   // pon
            241 + max_k37,   // open kan
            278,             // ron
        ];
        for top in family_tops {
            assert!(top < type_pad, "type index {top} collides with padding");
        }
    }
}