use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::Path;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum LockError {
#[error("failed to read lockfile: {0}")]
ReadError(#[from] std::io::Error),
#[error("failed to parse lockfile: {0}")]
ParseError(#[from] toml::de::Error),
#[error("failed to serialize lockfile: {0}")]
SerializeError(#[from] toml::ser::Error),
#[error("lockfile version mismatch: expected {expected}, found {found}")]
VersionMismatch { expected: u32, found: u32 },
}
/// Lockfile format version written by [`Lockfile::new`] and required by
/// [`Lockfile::parse`]; documents with any other `version` are rejected.
pub const LOCK_VERSION: u32 = 1;
/// A single pinned dependency entry in the lockfile.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LockedPackage {
    /// Package name; [`Lockfile::add_package`] keeps at most one entry per name.
    pub name: String,
    /// Pinned version string (e.g. `2.1.1`).
    pub version: String,
    /// Where the package comes from; deserializes to `"hackage"` when the key
    /// is absent (see `default_source`).
    #[serde(default = "default_source")]
    pub source: String,
    /// Optional hash for the package; omitted from the TOML when `None`.
    /// NOTE(review): the hash algorithm/format is not defined in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hash: Option<String>,
}
/// Serde default for [`LockedPackage::source`]: entries without an explicit
/// `source` key get the `"hackage"` source.
fn default_source() -> String {
    String::from("hackage")
}
/// Toolchain versions recorded in the lockfile; both default to `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct LockedToolchain {
    /// GHC version string (e.g. `9.8.2`).
    pub ghc: Option<String>,
    /// cabal-install version string (e.g. `3.12.1.0`).
    pub cabal: Option<String>,
}
/// Build-plan metadata; every field is omitted from the TOML when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct LockedPlan {
    /// Compiler identifier for the plan. Not included in [`Lockfile::fingerprint`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compiler_id: Option<String>,
    /// Target platform string. NOTE(review): format not defined here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub platform: Option<String>,
    /// Package-index state. Presumably a cabal `index-state` value — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index_state: Option<String>,
    /// Snapshot identifier, set via [`Lockfile::set_snapshot`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub snapshot: Option<String>,
    /// Plan hash. Not included in [`Lockfile::fingerprint`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hash: Option<String>,
}
/// A local package that is a member of the workspace.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkspacePackageInfo {
    /// Package name.
    pub name: String,
    /// Package version string.
    pub version: String,
    /// Location of the package. Presumably relative to the workspace root
    /// (the tests use `packages/<name>`) — confirm with the writer.
    pub path: String,
}
/// Workspace section of the lockfile.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct LockedWorkspace {
    /// Whether the project is a multi-package workspace; `false` when absent.
    #[serde(default)]
    pub is_workspace: bool,
    /// Local member packages; empty when absent and omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub packages: Vec<WorkspacePackageInfo>,
}
/// In-memory representation of the TOML lockfile.
///
/// Field order here is the order fields are serialized in: scalar fields come
/// first, followed by the table/array sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lockfile {
    /// Format version; [`Lockfile::parse`] rejects anything but [`LOCK_VERSION`].
    pub version: u32,
    /// Creation timestamp ([`Lockfile::new`] uses the current UTC time).
    /// Not included in [`Lockfile::fingerprint`].
    pub created_at: DateTime<Utc>,
    /// Toolchain versions; empty when the section is absent.
    #[serde(default)]
    pub toolchain: LockedToolchain,
    /// Build-plan metadata; empty when the section is absent.
    #[serde(default)]
    pub plan: LockedPlan,
    /// Workspace section; omitted from output when it carries no data
    /// (see `is_default_workspace`).
    #[serde(default, skip_serializing_if = "is_default_workspace")]
    pub workspace: LockedWorkspace,
    /// Locked dependencies; empty when absent.
    #[serde(default)]
    pub packages: Vec<LockedPackage>,
}
/// Serde `skip_serializing_if` predicate for [`Lockfile::workspace`].
///
/// Returns `true` when the section carries no information: it is not flagged
/// as a workspace and lists no member packages.
fn is_default_workspace(w: &LockedWorkspace) -> bool {
    let has_members = !w.packages.is_empty();
    !has_members && !w.is_workspace
}
impl Default for Lockfile {
fn default() -> Self {
Self::new()
}
}
impl Lockfile {
pub fn new() -> Self {
Self {
version: LOCK_VERSION,
created_at: Utc::now(),
toolchain: LockedToolchain::default(),
plan: LockedPlan::default(),
workspace: LockedWorkspace::default(),
packages: Vec::new(),
}
}
pub fn parse(s: &str) -> Result<Self, LockError> {
let lock: Lockfile = toml::from_str(s)?;
if lock.version != LOCK_VERSION {
return Err(LockError::VersionMismatch {
expected: LOCK_VERSION,
found: lock.version,
});
}
Ok(lock)
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, LockError> {
let content = std::fs::read_to_string(path)?;
Self::parse(&content)
}
pub fn to_string(&self) -> Result<String, LockError> {
Ok(toml::to_string_pretty(self)?)
}
pub fn to_file(&self, path: impl AsRef<Path>) -> Result<(), LockError> {
let content = self.to_string()?;
std::fs::write(path, content)?;
Ok(())
}
pub fn fingerprint(&self) -> String {
let mut hasher = Sha256::new();
if let Some(ref ghc) = self.toolchain.ghc {
hasher.update(format!("ghc:{}", ghc));
}
if let Some(ref cabal) = self.toolchain.cabal {
hasher.update(format!("cabal:{}", cabal));
}
if let Some(ref platform) = self.plan.platform {
hasher.update(format!("platform:{}", platform));
}
if let Some(ref index_state) = self.plan.index_state {
hasher.update(format!("index:{}", index_state));
}
if let Some(ref snapshot) = self.plan.snapshot {
hasher.update(format!("snapshot:{}", snapshot));
}
if self.workspace.is_workspace {
hasher.update("workspace:true");
let mut workspace_pkgs: Vec<_> = self.workspace.packages.iter().collect();
workspace_pkgs.sort_by(|a, b| a.name.cmp(&b.name));
for pkg in workspace_pkgs {
hasher.update(format!("local:{}@{}:{}", pkg.name, pkg.version, pkg.path));
}
}
let mut packages: Vec<_> = self.packages.iter().collect();
packages.sort_by(|a, b| a.name.cmp(&b.name));
for pkg in packages {
hasher.update(format!("{}@{}", pkg.name, pkg.version));
}
let result = hasher.finalize();
format!("sha256:{}", hex::encode(result))
}
pub fn add_package(&mut self, pkg: LockedPackage) {
self.packages.retain(|p| p.name != pkg.name);
self.packages.push(pkg);
}
pub fn set_toolchain(&mut self, ghc: Option<String>, cabal: Option<String>) {
self.toolchain.ghc = ghc;
self.toolchain.cabal = cabal;
}
pub fn set_snapshot(&mut self, snapshot: Option<String>) {
self.plan.snapshot = snapshot;
}
pub fn set_workspace(&mut self, packages: Vec<WorkspacePackageInfo>) {
self.workspace.is_workspace = !packages.is_empty();
self.workspace.packages = packages;
}
pub fn is_workspace(&self) -> bool {
self.workspace.is_workspace
}
pub fn workspace_package_names(&self) -> Vec<&str> {
self.workspace
.packages
.iter()
.map(|p| p.name.as_str())
.collect()
}
}
pub fn parse_freeze_file(content: &str) -> Vec<LockedPackage> {
let mut packages = Vec::new();
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with("--") {
continue;
}
let constraint = line
.strip_prefix("constraints:")
.or(Some(line))
.map(|s| s.trim().trim_end_matches(','));
if let Some(constraint) = constraint
&& let Some((name, version)) = parse_constraint(constraint)
{
packages.push(LockedPackage {
name,
version,
source: "hackage".to_string(),
hash: None,
});
}
}
packages
}
/// Parses a single exact-version constraint such as `text ==2.1.1` into
/// `(name, version)`.
///
/// Surrounding whitespace and trailing commas are ignored, and whitespace
/// around `==` is optional.
///
/// Returns `None` when the input is not an exact pin:
/// - there is no `==`, or either side of it is empty;
/// - the name contains anything other than letters, digits and `-`. This
///   covers qualified names like `any.base` and `field: value` lines;
/// - the version is not dot-separated digits. This covers wildcards like
///   `2.1.*` and trailing extra constraints.
fn parse_constraint(s: &str) -> Option<(String, String)> {
    let s = s.trim().trim_end_matches(',');
    let (name, version) = s.split_once("==")?;
    let (name, version) = (name.trim(), version.trim());

    // Qualified names such as `any.base` are rejected on purpose (pinned by
    // `test_parse_constraint`): `.` is not a package-name character.
    // NOTE(review): recent `cabal freeze` output writes every constraint as
    // `any.<pkg> ==<ver>`, which this therefore skips — confirm that is intended.
    let is_package_name =
        !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '-');
    let is_exact_version = version
        .split('.')
        .all(|part| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit()));

    (is_package_name && is_exact_version).then(|| (name.to_string(), version.to_string()))
}
mod hex {
    /// Lowercase hexadecimal alphabet, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Renders `bytes` as a lowercase hexadecimal string: two digits per byte,
    /// high nibble first. An empty input yields an empty string.
    pub fn encode(bytes: impl AsRef<[u8]>) -> String {
        let raw = bytes.as_ref();
        let mut out = String::with_capacity(raw.len() * 2);
        for &byte in raw {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Hackage-sourced package entry without a hash.
    fn hackage_pkg(name: &str, version: &str) -> LockedPackage {
        LockedPackage {
            name: name.to_string(),
            version: version.to_string(),
            source: "hackage".to_string(),
            hash: None,
        }
    }

    /// Builds a workspace member entry.
    fn member(name: &str, version: &str, path: &str) -> WorkspacePackageInfo {
        WorkspacePackageInfo {
            name: name.to_string(),
            version: version.to_string(),
            path: path.to_string(),
        }
    }

    /// Serializes `lock` to TOML and parses the result back.
    fn roundtrip(lock: &Lockfile) -> Lockfile {
        let serialized = lock.to_string().unwrap();
        Lockfile::parse(&serialized).unwrap()
    }

    #[test]
    fn test_lockfile_roundtrip() {
        let mut lock = Lockfile::new();
        lock.set_toolchain(Some("9.8.2".into()), Some("3.12.1.0".into()));
        lock.add_package(hackage_pkg("text", "2.1.1"));

        let parsed = roundtrip(&lock);

        assert_eq!(parsed.toolchain.ghc.as_deref(), Some("9.8.2"));
        assert_eq!(parsed.packages.len(), 1);
        assert_eq!(parsed.packages[0].name, "text");
    }

    #[test]
    fn test_workspace_lockfile_roundtrip() {
        let mut lock = Lockfile::new();
        lock.set_toolchain(Some("9.8.2".into()), Some("3.12.1.0".into()));
        lock.set_workspace(vec![
            member("mylib", "0.1.0", "packages/mylib"),
            member("myapp", "0.1.0", "packages/myapp"),
        ]);
        lock.add_package(hackage_pkg("text", "2.1.1"));

        let parsed = roundtrip(&lock);

        assert!(parsed.is_workspace());
        assert_eq!(parsed.workspace.packages.len(), 2);
        assert_eq!(parsed.packages.len(), 1);
        let names = parsed.workspace_package_names();
        for expected in ["mylib", "myapp"] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn test_workspace_fingerprint_includes_packages() {
        let lock_with = |name: &str| {
            let mut lock = Lockfile::new();
            lock.set_workspace(vec![member(name, "0.1.0", &format!("packages/{name}"))]);
            lock
        };

        assert_ne!(lock_with("pkg1").fingerprint(), lock_with("pkg2").fingerprint());
    }

    #[test]
    fn test_parse_constraint() {
        let pin = |name: &str, version: &str| Some((name.to_string(), version.to_string()));

        assert_eq!(parse_constraint("text ==2.1.1"), pin("text", "2.1.1"));
        assert_eq!(parse_constraint(" aeson ==2.2.0.0,"), pin("aeson", "2.2.0.0"));
        assert_eq!(parse_constraint("any.base ==4.19"), None);
    }
}