use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::DatasetSpec;
use crate::bench::Bench;
/// Dataset descriptor for Membench's `FirstAgent/simple.json`, registered
/// under [`Bench::MembenchSimpleRoles`].
// NOTE(review): `sha256` is empty — presumably no checksum is pinned for this
// file; confirm how `DatasetSpec` consumers treat an empty hash.
// NOTE(review): `bytes` is 4 MiB — presumably an expected/approximate file
// size rather than an exact one; confirm against `DatasetSpec`.
pub const SIMPLE_ROLES_SPEC: DatasetSpec = DatasetSpec {
bench: Bench::MembenchSimpleRoles,
filename: "simple.json",
url: "https://huggingface.co/datasets/import-myself/Membench/resolve/main/FirstAgent/simple.json",
sha256: "",
bytes: 4 * 1024 * 1024,
};
/// Dataset descriptor for Membench's `FirstAgent/highlevel.json`, registered
/// under [`Bench::MembenchHighlevelMovie`].
// NOTE(review): `sha256` is empty — presumably no checksum is pinned for this
// file; confirm how `DatasetSpec` consumers treat an empty hash.
// NOTE(review): `bytes` is 6 MiB — presumably an expected/approximate file
// size rather than an exact one; confirm against `DatasetSpec`.
pub const HIGHLEVEL_MOVIE_SPEC: DatasetSpec = DatasetSpec {
bench: Bench::MembenchHighlevelMovie,
filename: "highlevel.json",
url: "https://huggingface.co/datasets/import-myself/Membench/resolve/main/FirstAgent/highlevel.json",
sha256: "",
bytes: 6 * 1024 * 1024,
};
/// One dialogue turn parsed out of an item's `message_list`.
///
/// Every field carries `#[serde(default)]`, so a JSON object that lacks any
/// (or all) of these keys still deserializes, with the missing fields taking
/// their default value.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Turn {
/// User-side text; also accepted under the JSON key `user`.
#[serde(default, alias = "user")]
pub user_message: String,
/// Assistant-side text; also accepted under the JSON key `assistant`.
#[serde(default, alias = "assistant")]
pub assistant_message: String,
/// Optional integer id; `turn_sid` prefers this over `mid`.
#[serde(default)]
pub sid: Option<i64>,
/// Optional integer id; `turn_sid` falls back to it when `sid` is absent.
#[serde(default)]
pub mid: Option<i64>,
/// Free-form time text; `render_turn` emits it as a `[time] ` prefix when
/// non-empty.
#[serde(default)]
pub time: String,
/// Free-form place text; `render_turn` emits it as a ` (@place)` suffix
/// when non-empty.
#[serde(default)]
pub place: String,
}
/// Question record attached to an [`Item`] (stored under the JSON key `QA`).
///
/// Both fields default when missing from the input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Qa {
/// The question text.
#[serde(default)]
pub question: String,
/// Kept as raw JSON values, so entries of any JSON type are accepted.
// NOTE(review): presumably the ids of the turns that contain the answer —
// confirm against the dataset before relying on a specific shape.
#[serde(default)]
pub target_step_id: Vec<Value>,
}
/// One benchmark item: a conversation plus the question asked about it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Item {
/// Integer id read from the `tid` key; defaults to 0 when absent.
#[serde(default)]
pub tid: i64,
/// Raw conversation payload, left as untyped JSON; `flatten_turns` turns it
/// into a list of [`Turn`]s. Defaults to JSON `null` when absent.
#[serde(default, rename = "message_list")]
pub message_list: Value,
/// The question record, read from the key `QA`. This field has no default,
/// so deserialization fails when the key is missing.
#[serde(rename = "QA")]
pub qa: Qa,
/// Category label; omitted from serialized output when empty.
/// `load_filtered` overwrites it with the caller-supplied category.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub category: String,
/// Topic label; omitted from serialized output when empty.
/// `load_filtered` overwrites it with the map key the item was listed under.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub topic: String,
}
#[must_use]
pub fn flatten_turns(message_list: &Value) -> Vec<(usize, usize, usize, Turn)> {
let arr = match message_list.as_array() {
Some(a) => a,
None => return Vec::new(),
};
if arr.is_empty() {
return Vec::new();
}
let sessions: Vec<&[Value]> = if arr
.first()
.map(serde_json::Value::is_object)
.unwrap_or(false)
{
vec![arr.as_slice()]
} else {
arr.iter()
.filter_map(|v| v.as_array().map(std::vec::Vec::as_slice))
.collect()
};
let mut flat = Vec::new();
let mut g = 0usize;
for (s_idx, sess) in sessions.iter().enumerate() {
for (t_idx, raw) in sess.iter().enumerate() {
if let Ok(turn) = serde_json::from_value::<Turn>(raw.clone()) {
flat.push((g, s_idx, t_idx, turn));
g += 1;
}
}
}
flat
}
/// Renders a turn as a single line of text.
///
/// The output is the whitespace-trimmed `user_message`, preceded by
/// `[time] ` when `time` is non-empty and followed by ` (@place)` when
/// `place` is non-empty. `time` and `place` are inserted verbatim (not
/// trimmed); `assistant_message` is not part of the output.
#[must_use]
pub fn render_turn(turn: &Turn) -> String {
    let time_tag = (!turn.time.is_empty())
        .then(|| format!("[{}] ", turn.time))
        .unwrap_or_default();
    let place_tag = (!turn.place.is_empty())
        .then(|| format!(" (@{})", turn.place))
        .unwrap_or_default();
    let body = turn.user_message.trim();
    [time_tag.as_str(), body, place_tag.as_str()].concat()
}
/// Picks the id to use for a turn.
///
/// Returns `turn.sid` when present, otherwise `turn.mid`, otherwise the
/// turn's index `g` converted to `i64` (saturating at `i64::MAX` if `g`
/// does not fit).
#[must_use]
pub fn turn_sid(g: usize, turn: &Turn) -> i64 {
    match (turn.sid, turn.mid) {
        (Some(sid), _) => sid,
        (None, Some(mid)) => mid,
        (None, None) => i64::try_from(g).unwrap_or(i64::MAX),
    }
}
/// Loads the items of a topic-keyed JSON file, optionally keeping one topic.
///
/// The file at `path` must hold a JSON object that maps topic names to arrays
/// of [`Item`]s. Topics are visited in sorted key order (the map is a
/// `BTreeMap`). Every returned item has `category` set to the `category`
/// argument and `topic` set to the key it was listed under, replacing
/// whatever values the file carried.
///
/// When `topic` is `Some`, only the key equal to it is kept; if no key
/// matches, the result is an empty vector rather than an error.
///
/// # Errors
///
/// Returns an error, annotated with the path, if the file cannot be read or
/// if its contents do not parse as the expected map.
pub fn load_filtered(path: &Path, category: &str, topic: Option<&str>) -> Result<Vec<Item>> {
    let raw = fs::read(path).with_context(|| format!("reading {}", path.display()))?;
    let grouped: BTreeMap<String, Vec<Item>> =
        serde_json::from_slice(&raw).with_context(|| format!("parsing {}", path.display()))?;
    let selected: Vec<Item> = grouped
        .into_iter()
        .filter(|(name, _)| match topic {
            Some(want) => name.as_str() == want,
            None => true,
        })
        .flat_map(|(name, items)| {
            // Stamp each item with the caller's category and its own map key.
            items.into_iter().map(move |mut item| {
                item.category = category.to_string();
                item.topic = name.clone();
                item
            })
        })
        .collect();
    Ok(selected)
}