use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
path::{Path, PathBuf},
};
use toml_edit::{Array, Document, DocumentMut, InlineTable, Table};
/// Typed view of a `pyproject.toml`, deserialized with the `toml` crate.
///
/// Keys are kebab-case in the file (`dependency-groups`, `build-system`).
/// `[project]` is required; the remaining top-level tables are optional.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct PyProjectConfig {
    /// The `[project]` table.
    pub project: Project,
    /// The `[dependency-groups]` table, if present.
    pub dependency_groups: Option<DependencyGroups>,
    /// The `[build-system]` table, if present.
    pub build_system: Option<BuildSystem>,
    /// The `[tool]` table, limited to the parts modelled by [`Tool`].
    pub tool: Option<Tool>,
}
impl Default for PyProjectConfig {
    /// Builds a fresh project skeleton: default `[project]` metadata and every
    /// optional section present, each filled with its own `Default`.
    fn default() -> Self {
        let project = Project::default();
        let dependency_groups = Some(DependencyGroups::default());
        let build_system = Some(BuildSystem::default());
        let tool = Some(Tool::default());
        Self {
            project,
            dependency_groups,
            build_system,
            tool,
        }
    }
}
/// One entry of a dependency group: either a plain requirement string or an
/// `{ include-group = "<name>" }` reference to another group.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// a bare string becomes `String`, an inline table becomes `IncludeGroup`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum DependencyGroupItem {
    /// A requirement specifier such as `"pytest>=7.0"`.
    String(String),
    /// A reference pulling in every item of another group.
    IncludeGroup {
        #[serde(rename = "include-group")]
        include_group: String,
    },
}
/// The `[dependency-groups]` table, keyed by group name.
///
/// `flatten` maps every key of the table directly into `groups`. Because the
/// storage is a `HashMap`, the file's group order is not kept on serialization.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct DependencyGroups {
    #[serde(flatten)]
    pub groups: HashMap<String, Vec<DependencyGroupItem>>,
}
/// The `[project]` table (PEP 621 core metadata).
///
/// Only `name` is mandatory here. `version`, `description`, `requires-python`
/// and `dependencies` fall back to empty values when absent. The packaging
/// spec treats them as optional (`version` may also be listed in `dynamic`),
/// and a real-world `pyproject.toml` must still parse. Field-level defaults
/// are used on purpose: they yield `""` / `[]`, not the scaffold values from
/// `Project::default()`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Project {
    pub name: String,
    #[serde(default)]
    pub version: String,
    #[serde(default)]
    pub description: String,
    #[serde(default)]
    pub requires_python: String,
    #[serde(default)]
    pub dependencies: Vec<String>,
    pub authors: Option<Vec<Author>>,
    // NOTE(review): PEP 621 also allows a table form for `readme`
    // (`{ file = "...", content-type = "..." }`), which this `String` rejects.
    pub readme: Option<String>,
    pub urls: Option<HashMap<String, String>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Author {
pub name: String,
pub email: String,
}
impl Default for Project {
fn default() -> Self {
Self {
name: String::from("awesome-bot"),
version: String::from("0.1.0"),
description: String::from("a nonebot project"),
requires_python: String::from(">=3.10"),
dependencies: vec![],
authors: Some(vec![]),
readme: Some(String::from("README.md")),
urls: None,
}
}
}
/// The `[tool]` table; only the `tool.nonebot` sub-table is modelled.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Tool {
    /// The `[tool.nonebot]` table, if present.
    pub nonebot: Option<Nonebot>,
}
impl Default for Tool {
fn default() -> Self {
Self {
nonebot: Some(Nonebot::default()),
}
}
}
/// The `[tool.nonebot]` table.
///
/// There is no `rename_all` here, so the TOML keys are the snake_case field
/// names (`plugin_dirs`, `builtin_plugins`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Nonebot {
    /// Adapter entries, each a `{ name, module_name }` table.
    pub adapters: Option<Vec<Adapter>>,
    /// Plugins to load — presumably importable module names; TODO confirm.
    pub plugins: Option<Vec<String>>,
    /// Plugin directories — presumably relative to the project root; TODO confirm.
    pub plugin_dirs: Option<Vec<String>>,
    /// Built-in plugins to enable.
    pub builtin_plugins: Option<Vec<String>>,
}
impl Default for Nonebot {
fn default() -> Self {
Self {
adapters: Some(vec![]),
plugins: Some(vec![]),
plugin_dirs: Some(vec![]),
builtin_plugins: Some(vec![]),
}
}
}
/// One adapter entry in `tool.nonebot.adapters`.
///
/// The TOML keys are the field names as written (`name`, `module_name`).
/// `Eq + Hash` allow adapters to be collected into sets.
#[derive(Serialize, Deserialize, Default, Debug, Clone, Eq, PartialEq, Hash)]
pub struct Adapter {
    /// Display name of the adapter.
    pub name: String,
    /// Importable Python module path, typically under `nonebot.adapters.`.
    pub module_name: String,
}
impl Adapter {
    /// Derives an import alias for the adapter, e.g.
    /// `nonebot.adapters.onebot.v11` → `OnebotV11Adapter`.
    ///
    /// Leading `nonebot.adapters.` prefixes are stripped. Each remaining
    /// dot-separated segment has its first character ASCII-uppercased, the
    /// segments are concatenated, and `Adapter` is appended.
    ///
    /// Segments are processed per `char`, not per byte. An empty segment (from
    /// an empty `module_name`, a trailing dot, or `a..b`) contributes nothing,
    /// and a non-ASCII leading character is kept unchanged instead of causing
    /// a char-boundary panic.
    pub fn alias(&self) -> String {
        let mut alias: String = self
            .module_name
            .trim_start_matches("nonebot.adapters.")
            .split('.')
            .map(Self::capitalize_ascii_first)
            .collect();
        alias.push_str("Adapter");
        alias
    }

    /// Returns `segment` with its first character ASCII-uppercased; `""` stays `""`.
    fn capitalize_ascii_first(segment: &str) -> String {
        let mut chars = segment.chars();
        match chars.next() {
            Some(first) => {
                let mut out = String::with_capacity(segment.len());
                out.push(first.to_ascii_uppercase());
                out.push_str(chars.as_str());
                out
            }
            None => String::new(),
        }
    }
}
/// The `[build-system]` table; fields map to the kebab-case keys
/// `requires` and `build-backend`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct BuildSystem {
    /// Requirements needed to run the build backend.
    pub requires: Vec<String>,
    /// Import path of the build backend.
    pub build_backend: String,
}
impl Default for BuildSystem {
fn default() -> Self {
Self {
requires: vec!["uv_build>=0.9.0,<0.10.0".to_string()],
build_backend: "uv_build".to_string(),
}
}
}
impl PyProjectConfig {
    /// Reads and deserializes `pyproject.toml` from `work_dir`, or from the
    /// current directory when `work_dir` is `None`.
    ///
    /// # Errors
    /// Fails if the file does not exist, cannot be read, or does not match the
    /// [`PyProjectConfig`] schema.
    pub fn parse(work_dir: Option<&Path>) -> Result<Self> {
        let toml_path = work_dir.map_or_else(
            || Path::new("pyproject.toml").to_path_buf(),
            |dir| dir.join("pyproject.toml"),
        );
        if !toml_path.exists() {
            anyhow::bail!("{} does not exist", toml_path.display());
        }
        let content =
            std::fs::read_to_string(&toml_path).context("Failed to read pyproject.toml")?;
        Self::parse_from_str(&content)
    }

    /// Deserializes TOML text into a [`PyProjectConfig`].
    ///
    /// # Errors
    /// Fails when `content` is not valid TOML or does not match the schema.
    pub fn parse_from_str(content: &str) -> Result<Self> {
        toml::from_str::<Self>(content)
            .context("Failed to parse pyproject.toml to PyProjectConfig")
    }

    /// Shorthand for [`Self::parse`] with no working directory.
    #[allow(unused)] // convenience entry point; presumably no caller in this crate yet
    pub fn parse_current_dir() -> Result<Self> {
        Self::parse(None)
    }

    /// Borrows the `[tool.nonebot]` section, if both `tool` and `nonebot` are present.
    pub fn nonebot(&self) -> Option<&Nonebot> {
        let tool = self.tool.as_ref()?;
        tool.nonebot.as_ref()
    }
}
/// Editor for the `[tool.nonebot]` section of a `pyproject.toml`, built on
/// `toml_edit` so that the rest of the document is rewritten from the parsed
/// representation rather than regenerated from a schema.
///
/// Every public mutating method writes the whole document back to `toml_path`.
#[derive(Debug, Clone)]
pub struct NbTomlEditor {
    /// File that the document is written to on save.
    toml_path: PathBuf,
    /// Parsed, mutable document.
    doc_mut: DocumentMut,
}
impl NbTomlEditor {
pub fn with_str(content: &str, save_path: &Path) -> Result<Self> {
let toml_path = save_path.to_path_buf();
let doc = Document::parse(content).context("Failed to parse pyproject.toml")?;
let doc_mut = doc.into_mut();
Ok(Self { toml_path, doc_mut })
}
pub fn with_work_dir(work_dir: Option<&Path>) -> Result<Self> {
let toml_path = if let Some(work_dir) = work_dir {
work_dir.join("pyproject.toml")
} else {
Path::new("pyproject.toml").to_path_buf()
};
let mut content =
std::fs::read_to_string(toml_path.clone()).context("Failed to read pyproject.toml")?;
if !content.contains("[tool.nonebot]") {
content.push_str(
format!(
include_str!("cli/templates/pyproject/tool_nonebot"),
"", "", ""
)
.as_str(),
);
}
Self::with_str(&content, &toml_path)
}
fn nonebot_table_mut(&mut self) -> Result<&mut Table> {
self.doc_mut["tool"]["nonebot"]
.as_table_mut()
.context("tool.nonebot is not a table")
}
fn adapters_array_mut(&mut self) -> Result<&mut Array> {
let table = self.nonebot_table_mut()?;
let item = table
.get_mut("adapters")
.context("adapters not found in tool.nonebot")?;
item.as_array_mut().context("adapters is not an array")
}
fn plugins_array_mut(&mut self) -> Result<&mut Array> {
let table = self.nonebot_table_mut()?;
let item = table
.get_mut("plugins")
.context("plugins not found in tool.nonebot")?;
item.as_array_mut().context("plugins is not an array")
}
fn save(&self) -> Result<()> {
std::fs::write(self.toml_path.clone(), self.doc_mut.to_string())?;
Ok(())
}
fn fmt_toml_array(array: &mut toml_edit::Array) {
array.iter_mut().for_each(|a| {
let decor_mut = a.decor_mut();
decor_mut.set_prefix("\n ");
decor_mut.set_suffix("");
});
if let Some(last) = array.iter_mut().last() {
last.decor_mut().set_suffix("\n");
}
}
pub fn add_adapters(&mut self, adapters: Vec<Adapter>) -> Result<()> {
let adapters = adapters.into_iter().collect::<HashSet<Adapter>>();
let adapters_arr_mut = self.adapters_array_mut()?;
for adapter in adapters {
let mut inline_table = InlineTable::new();
inline_table.insert("name", adapter.name.into());
inline_table.insert("module_name", adapter.module_name.into());
adapters_arr_mut.push(inline_table);
}
Self::fmt_toml_array(adapters_arr_mut);
self.save()
}
pub fn remove_adapters(&mut self, adapter_names: Vec<&str>) -> Result<()> {
let adapters_arr_mut = self.adapters_array_mut()?;
adapters_arr_mut.retain(|a| {
a.as_inline_table()
.and_then(|table| table.get("name"))
.and_then(|v| v.as_str())
.is_none_or(|name| !adapter_names.contains(&name))
});
self.save()
}
pub fn add_plugins(&mut self, plugins: Vec<&str>) -> Result<()> {
let mut plugins = plugins.into_iter().collect::<HashSet<&str>>();
let plugins_arr_mut = self.plugins_array_mut()?;
let plugin_names = plugins_arr_mut
.iter()
.filter_map(|p| p.as_str())
.collect::<Vec<&str>>();
plugins.retain(|p| !plugin_names.contains(p));
plugins_arr_mut.extend(plugins);
Self::fmt_toml_array(plugins_arr_mut);
self.save()
}
pub fn remove_plugins(&mut self, plugins: Vec<&str>) -> Result<()> {
let plugins_arr_mut = self.plugins_array_mut()?;
plugins_arr_mut.retain(|p| {
if let Some(name) = p.as_str() {
!plugins.contains(&name)
} else {
true
}
});
Self::fmt_toml_array(plugins_arr_mut);
self.save()
}
pub fn reset_plugins(&mut self, plugins: Vec<&str>) -> Result<()> {
let plugins_arr_mut = self.plugins_array_mut()?;
plugins_arr_mut.clear();
plugins_arr_mut.extend(plugins);
Self::fmt_toml_array(plugins_arr_mut);
self.save()
}
#[allow(unused)]
pub fn reset_adapters(&mut self, adapters: Vec<Adapter>) -> Result<()> {
let adapters_arr_mut = self.adapters_array_mut()?;
adapters_arr_mut.clear();
adapters_arr_mut.extend(adapters.into_iter().map(|adapter| {
let mut inline_table = InlineTable::new();
inline_table.insert("name", adapter.name.into());
inline_table.insert("module_name", adapter.module_name.into());
inline_table
}));
self.save()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Parsing: a `[dependency-groups]` group may mix plain requirement strings
    // with `{ include-group = "..." }` tables; each must land in the matching
    // `DependencyGroupItem` variant, in file order.
    #[test]
    fn test_dependency_groups_with_include() {
        let toml_content = r#"
[project]
name = "test-project"
version = "0.1.0"
description = "Test project"
requires-python = ">=3.10"
dependencies = []
[dependency-groups]
test = ["pytest>=7.0", "coverage"]
typing = ["mypy", "types-requests"]
dev = [
{ include-group = "test" },
{ include-group = "typing" },
"ruff"
]
"#;
        let pyproject =
            PyProjectConfig::parse_from_str(toml_content).expect("Failed to parse test TOML");
        let dep_groups = pyproject
            .dependency_groups
            .expect("dependency_groups should be present");
        // A group made only of strings.
        let test_group = dep_groups
            .groups
            .get("test")
            .expect("test group should be present");
        assert_eq!(test_group.len(), 2);
        assert!(matches!(&test_group[0], DependencyGroupItem::String(s) if s == "pytest>=7.0"));
        assert!(matches!(&test_group[1], DependencyGroupItem::String(s) if s == "coverage"));
        // A group mixing include-group tables and a string.
        let dev_group = dep_groups
            .groups
            .get("dev")
            .expect("dev group should be present");
        assert_eq!(dev_group.len(), 3);
        assert!(
            matches!(&dev_group[0], DependencyGroupItem::IncludeGroup { include_group } if include_group == "test")
        );
        assert!(
            matches!(&dev_group[1], DependencyGroupItem::IncludeGroup { include_group } if include_group == "typing")
        );
        assert!(matches!(&dev_group[2], DependencyGroupItem::String(s) if s == "ruff"));
    }

    // Serialization round-trip: the kebab-case table name and the
    // `include-group` key must appear in the output, and re-parsing it must
    // yield the same number of groups.
    #[test]
    fn test_dependency_groups_serialization() {
        let mut pyproject = PyProjectConfig::default();
        let mut groups = std::collections::HashMap::new();
        groups.insert(
            "test".to_string(),
            vec![
                DependencyGroupItem::String("pytest>=7.0".to_string()),
                DependencyGroupItem::String("coverage".to_string()),
            ],
        );
        groups.insert(
            "dev".to_string(),
            vec![
                DependencyGroupItem::IncludeGroup {
                    include_group: "test".to_string(),
                },
                DependencyGroupItem::String("ruff".to_string()),
            ],
        );
        pyproject.dependency_groups = Some(DependencyGroups { groups });
        let toml_str = toml::to_string(&pyproject).expect("Failed to serialize pyproject");
        println!("Serialized TOML:\n{}", toml_str);
        // Substring checks only: group order in the output is not fixed
        // because `groups` is a HashMap.
        assert!(toml_str.contains("[dependency-groups]"));
        assert!(toml_str.contains("test = ["));
        assert!(toml_str.contains("\"pytest>=7.0\""));
        assert!(toml_str.contains("dev = ["));
        assert!(toml_str.contains("include-group = \"test\""));
        let parsed: PyProjectConfig =
            toml::from_str(&toml_str).expect("Failed to parse serialized TOML");
        let parsed_groups = parsed
            .dependency_groups
            .expect("dependency_groups should be present");
        assert_eq!(parsed_groups.groups.len(), 2);
    }

    // Order within a group is preserved: an `include-group` placed first in
    // `dev` must still precede the plain requirements after serialization.
    #[test]
    fn test_dev_group_includes_test_first() {
        let mut pyproject = PyProjectConfig::default();
        let dev_deps = vec!["ruff>=0.14.8".to_string(), "pre-commit>=4.3.0".to_string()];
        let test_group_items: Vec<DependencyGroupItem> = vec![
            DependencyGroupItem::String("pytest>=7.0".to_string()),
            DependencyGroupItem::String("coverage".to_string()),
        ];
        let mut dev_group_items: Vec<DependencyGroupItem> = vec![
            DependencyGroupItem::IncludeGroup {
                include_group: String::from("test"),
            },
        ];
        dev_group_items.extend(dev_deps.into_iter().map(DependencyGroupItem::String));
        // `PyProjectConfig::default()` provides `Some(DependencyGroups)`, so
        // the groups can be inserted in place.
        let dep_groups = pyproject
            .dependency_groups
            .as_mut()
            .expect("dependency_groups should be present");
        dep_groups
            .groups
            .insert("test".to_string(), test_group_items);
        dep_groups.groups.insert("dev".to_string(), dev_group_items);
        let dev_group = dep_groups
            .groups
            .get("dev")
            .expect("dev group should be present");
        assert_eq!(dev_group.len(), 3);
        assert!(
            matches!(&dev_group[0], DependencyGroupItem::IncludeGroup { include_group } if include_group == "test")
        );
        assert!(matches!(&dev_group[1], DependencyGroupItem::String(s) if s == "ruff>=0.14.8"));
        assert!(
            matches!(&dev_group[2], DependencyGroupItem::String(s) if s == "pre-commit>=4.3.0")
        );
        // Search only from the start of the `dev` array so that matches in
        // other groups cannot affect the position comparison.
        let toml_str = toml::to_string(&pyproject).expect("Failed to serialize pyproject");
        let dev_line_start = toml_str.find("dev = [").expect("dev group not found");
        let include_pos = toml_str[dev_line_start..]
            .find("include-group")
            .expect("include-group not found");
        let ruff_pos = toml_str[dev_line_start..]
            .find("ruff")
            .expect("ruff not found");
        assert!(
            include_pos < ruff_pos,
            "include-group should come before ruff in serialized TOML"
        );
    }
}