use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::OnceLock;
use thiserror::Error;
/// A musical scale loaded from the bundled scale catalogue.
///
/// A scale is described either by a single `intervals` / `notes` pair or, when its
/// ascending and descending forms differ, by the `_ascending` / `_descending` lists
/// (e.g. "Enigmatic" only has the directional lists). Every list is optional because
/// the source data omits fields that do not apply to a given scale.
///
/// `Default` is derived so partial scales can be built with struct-update syntax, and
/// `Hash` so scales can be stored in hash-based collections.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Scale {
    /// Display name. Duplicate names in the data are disambiguated with a ` (n)` suffix.
    pub name: String,
    /// Step sizes between consecutive degrees (values look like semitones — TODO confirm).
    pub intervals: Option<Vec<u8>>,
    /// Step sizes used when the scale is played upwards, if it differs from downwards.
    pub intervals_ascending: Option<Vec<u8>>,
    /// Step sizes used when the scale is played downwards, if it differs from upwards.
    pub intervals_descending: Option<Vec<u8>>,
    /// Degrees as offsets from the tonic; the bundled data starts these at 0.
    pub notes: Option<Vec<u8>>,
    /// Degree offsets for the ascending form of a direction-dependent scale.
    pub notes_ascending: Option<Vec<u8>>,
    /// Degree offsets for the descending form of a direction-dependent scale.
    pub notes_descending: Option<Vec<u8>>,
    /// Free-form origin label from the data (e.g. "Egypt"), if any.
    pub origin: Option<String>,
}
#[derive(Debug, Error)]
pub enum ScaleOmnibusError {
#[error("Scale not found: {0}")]
ScaleNotFoundError(String),
}
/// Lazily-initialised table of every scale in `data/scales.yaml`, keyed by lower-cased name.
static SCALES: OnceLock<HashMap<String, Scale>> = OnceLock::new();

/// Returns the process-wide scale table, parsing the embedded YAML on first use.
///
/// Keys are the lower-cased scale names. Entries without a string `name` are skipped.
/// When a key is already taken (a duplicate name, possibly differing only in case),
/// the entry is stored under `"<key> (n)"` and its `name` becomes `"<name> (n)"`. The
/// counter `n` starts at 1 per base name and keeps increasing until a free key is
/// found, so no earlier entry is ever overwritten.
///
/// # Panics
///
/// Panics if the embedded YAML cannot be deserialized as a sequence. The file is
/// compiled in with `include_str!`, so that indicates a data/build bug.
fn load_scales() -> &'static HashMap<String, Scale> {
    SCALES.get_or_init(|| {
        const SCALES_YAML: &str = include_str!("../data/scales.yaml");
        let items: Vec<serde_yaml::Value> =
            serde_yaml::from_str(SCALES_YAML).expect("Invalid YAML format");
        let mut scale_map: HashMap<String, Scale> = HashMap::with_capacity(items.len());
        // Per base key: how many suffixed copies have been handed out so far.
        let mut dup_counts: HashMap<String, usize> = HashMap::new();
        for scale in items.iter().filter_map(parse_scale) {
            insert_unique(&mut scale_map, &mut dup_counts, scale);
        }
        scale_map
    })
}

/// Builds a [`Scale`] from one YAML entry; returns `None` when the entry has no string `name`.
fn parse_scale(item: &serde_yaml::Value) -> Option<Scale> {
    let name = item.get("name")?.as_str()?;
    Some(Scale {
        name: name.to_string(),
        intervals: parse_u8_list(item, "intervals"),
        intervals_ascending: parse_u8_list(item, "intervals_ascending"),
        intervals_descending: parse_u8_list(item, "intervals_descending"),
        notes: parse_u8_list(item, "notes"),
        notes_ascending: parse_u8_list(item, "notes_ascending"),
        notes_descending: parse_u8_list(item, "notes_descending"),
        origin: item
            .get("origin")
            .and_then(|o| o.as_str())
            .map(str::to_string),
    })
}

/// Reads `item[field]` as a list of `u8` values.
///
/// Returns `None` when the field is missing or is not a sequence. Elements that are not
/// non-negative integers, or that exceed `u8::MAX`, are skipped.
fn parse_u8_list(item: &serde_yaml::Value, field: &str) -> Option<Vec<u8>> {
    let values = item.get(field)?.as_sequence()?;
    Some(
        values
            .iter()
            .filter_map(|v| v.as_u64())
            // A bare `as u8` cast would silently wrap out-of-range values (256 -> 0).
            .filter(|&n| n <= u64::from(u8::MAX))
            .map(|n| n as u8)
            .collect(),
    )
}

/// Inserts `scale` under its lower-cased name, appending ` (n)` to both key and name
/// until an unused key is found, so an existing entry is never overwritten.
fn insert_unique(
    scale_map: &mut HashMap<String, Scale>,
    dup_counts: &mut HashMap<String, usize>,
    mut scale: Scale,
) {
    let base_name = scale.name.clone();
    let base_key = base_name.to_lowercase();
    let mut key = base_key.clone();
    loop {
        match scale_map.entry(key) {
            Entry::Vacant(slot) => {
                slot.insert(scale);
                return;
            }
            Entry::Occupied(_) => {
                let n = dup_counts.entry(base_key.clone()).or_insert(0);
                *n += 1;
                key = format!("{base_key} ({n})");
                scale.name = format!("{base_name} ({n})");
            }
        }
    }
}
/// Looks up a single scale by name, case-insensitively.
///
/// The query is lower-cased and matched against the table's lower-cased keys.
/// Disambiguated duplicates are reachable via their suffixed key, e.g. `"name (1)"`.
///
/// # Errors
///
/// Returns [`ScaleOmnibusError::ScaleNotFoundError`] carrying `name` exactly as given
/// when no scale matches.
pub fn get_scale(name: &str) -> Result<&Scale, ScaleOmnibusError> {
    let lookup_key = name.to_lowercase();
    match load_scales().get(&lookup_key) {
        Some(found) => Ok(found),
        None => Err(ScaleOmnibusError::ScaleNotFoundError(name.to_string())),
    }
}
/// Returns the lookup key of every loaded scale, sorted alphabetically.
///
/// The keys are the lower-cased names accepted by [`get_scale`], including any
/// ` (n)` suffix added to disambiguate duplicate names.
pub fn get_scale_names() -> Vec<String> {
    let mut names: Vec<String> = load_scales().keys().cloned().collect();
    // HashMap iteration order is randomized per process; sort so the listing is stable.
    names.sort_unstable();
    names
}
/// Returns clones of every loaded scale for which `filter` returns `true`.
///
/// The predicate only needs to be `FnMut`, so stateful predicates (counters, limits)
/// are accepted as well as plain `Fn` closures. Result order follows the internal
/// hash map and is not specified.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is kept for API stability.
pub fn filter_scales<F>(mut filter: F) -> Result<Vec<Scale>, ScaleOmnibusError>
where
    F: FnMut(&Scale) -> bool,
{
    let scales = load_scales();
    Ok(scales
        .values()
        .filter(|&scale| filter(scale))
        // Clone only the scales that passed the predicate.
        .cloned()
        .collect::<Vec<Scale>>())
}
/// Returns every scale whose `intervals` list has strictly more than `min_intervals` entries.
///
/// Scales without a plain `intervals` list (for example, those described only by
/// ascending/descending lists) never match.
pub fn find_scales_with_intervals_greater_than(
    min_intervals: usize,
) -> Result<Vec<Scale>, ScaleOmnibusError> {
    let has_enough_steps = |scale: &Scale| match &scale.intervals {
        Some(steps) => steps.len() > min_intervals,
        None => false,
    };
    filter_scales(has_enough_steps)
}
/// Returns every scale whose `origin` equals `origin`, compared case-insensitively.
///
/// Both sides are compared after `str::to_lowercase`. Scales with no origin never match.
pub fn find_scales_by_origin(origin: &str) -> Result<Vec<Scale>, ScaleOmnibusError> {
    // Lower-case the query once, not once per scale inside the predicate.
    let wanted = origin.to_lowercase();
    filter_scales(|scale| {
        scale
            .origin
            .as_deref()
            .is_some_and(|o| o.to_lowercase() == wanted)
    })
}
/// Returns the scales that define a separate ascending interval list.
///
/// Only `intervals_ascending` is inspected; `intervals_descending` is not checked
/// here (the data presumably provides both together — TODO confirm).
pub fn find_scales_with_up_down_intervals() -> Result<Vec<Scale>, ScaleOmnibusError> {
    let has_ascending = |scale: &Scale| matches!(scale.intervals_ascending, Some(_));
    filter_scales(has_ascending)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_scale_names() {
        let names = get_scale_names();
        let n = names.len();
        assert!(
            n > 1000,
            "There are some missing scales, because only {n} scales were found."
        );
        assert!(names.iter().any(|name| name == "major"));
        assert!(names.iter().any(|name| name == "superlocrian #6"));
    }

    #[test]
    fn test_get_scale() -> Result<(), Box<dyn std::error::Error>> {
        // Lookup is case-insensitive; the stored name keeps the data file's casing.
        let bebop_major = get_scale("BeBop majoR")?;
        assert_eq!(bebop_major.name, "Bebop Major");
        assert_eq!(bebop_major.intervals, Some(vec![2, 2, 1, 2, 1, 1, 2, 1]));
        assert_eq!(bebop_major.notes, Some(vec![0, 2, 4, 5, 7, 8, 9, 11]));
        assert_eq!(bebop_major.notes_ascending, None);
        assert_eq!(bebop_major.intervals_ascending, None);
        assert_eq!(bebop_major.notes_descending, None);
        assert_eq!(bebop_major.intervals_descending, None);
        Ok(())
    }

    #[test]
    fn test_missing_scale() {
        let query = "Definitely Not A Scale";
        match get_scale(query) {
            Err(ScaleOmnibusError::ScaleNotFoundError(name)) => assert_eq!(name, query),
            Ok(scale) => panic!("unexpectedly found scale {:?}", scale.name),
        }
    }

    #[test]
    fn test_edge_name() -> Result<(), Box<dyn std::error::Error>> {
        let superlocrian = get_scale("superlocrian #6")?;
        assert_eq!(superlocrian.name, "Superlocrian #6");
        assert_eq!(superlocrian.intervals, Some(vec![1, 2, 1, 2, 3, 1, 2]));
        assert_eq!(superlocrian.notes, Some(vec![0, 1, 3, 4, 6, 9, 10]));
        assert_eq!(superlocrian.notes_ascending, None);
        assert_eq!(superlocrian.intervals_ascending, None);
        assert_eq!(superlocrian.notes_descending, None);
        assert_eq!(superlocrian.intervals_descending, None);
        Ok(())
    }

    #[test]
    fn test_conflict_name() -> Result<(), Box<dyn std::error::Error>> {
        // The data file lists this name twice; the second copy gets a " (1)" suffix.
        let messiaen_2nd_one = get_scale("Messiaen 2nd Mode")?;
        let messiaen_2nd_two = get_scale("Messiaen 2nd Mode (1)")?;
        assert_eq!(messiaen_2nd_one.name, "Messiaen 2nd Mode");
        assert_eq!(messiaen_2nd_two.name, "Messiaen 2nd Mode (1)");
        Ok(())
    }

    #[test]
    fn test_notes_ascending() -> Result<(), Box<dyn std::error::Error>> {
        let enigmatic = get_scale("Enigmatic")?;
        assert_eq!(enigmatic.name, "Enigmatic");
        assert_eq!(enigmatic.intervals, None);
        assert_eq!(
            enigmatic.intervals_ascending,
            Some(vec![1, 3, 2, 2, 2, 1, 1])
        );
        assert_eq!(
            enigmatic.intervals_descending,
            Some(vec![1, 3, 1, 3, 2, 1, 1])
        );
        assert_eq!(enigmatic.notes_ascending, Some(vec![0, 1, 4, 6, 8, 10, 11]));
        assert_eq!(
            enigmatic.notes_descending,
            Some(vec![0, 1, 4, 5, 8, 10, 11])
        );
        assert_eq!(enigmatic.notes, None);
        Ok(())
    }

    #[test]
    fn test_filter_scales_by_name() -> Result<(), Box<dyn std::error::Error>> {
        let filtered_scales = filter_scales(|scale| scale.name.to_lowercase().contains("major"))?;
        let bebop_major = get_scale("Bebop major")?;
        let aeolian_major = get_scale("Aeolian Major")?;
        let major_pentatonic = get_scale("Major Pentatonic b7 #9")?;
        assert!(filtered_scales.contains(bebop_major));
        assert!(filtered_scales.contains(aeolian_major));
        assert!(filtered_scales.contains(major_pentatonic));
        Ok(())
    }

    #[test]
    fn test_filter_scales_by_number_of_intervals() -> Result<(), Box<dyn std::error::Error>> {
        let filtered_scales = filter_scales(|scale| {
            scale
                .intervals
                .as_ref()
                .map_or(false, |intervals| intervals.len() == 12)
        })?;
        assert_eq!(
            filtered_scales.len(),
            1,
            "There should only be one scale with 12 intervals (chromatic)."
        );
        assert_eq!(&filtered_scales[0], get_scale("Chromatic")?);
        Ok(())
    }

    #[test]
    fn test_find_scales_with_intervals_greater_than() -> Result<(), Box<dyn std::error::Error>> {
        let filtered_scales = find_scales_with_intervals_greater_than(5)?;
        assert!(
            !filtered_scales.is_empty(),
            "No scales found with >5 intervals"
        );
        for scale in &filtered_scales {
            let len = scale.intervals.as_ref().map_or(0, Vec::len);
            assert!(len > 5, "{:?} has only {len} intervals", scale.name);
        }
        Ok(())
    }

    #[test]
    fn test_find_scales_by_origin() -> Result<(), Box<dyn std::error::Error>> {
        let filtered_scales = find_scales_by_origin("Egypt")?;
        assert!(
            !filtered_scales.is_empty(),
            "No scales found originating from Egypt"
        );
        for scale in &filtered_scales {
            assert_eq!(
                scale.origin.as_deref().map(str::to_lowercase).as_deref(),
                Some("egypt"),
                "{:?} is not from Egypt",
                scale.name
            );
        }
        Ok(())
    }

    #[test]
    fn test_find_scales_with_up_down_intervals() -> Result<(), Box<dyn std::error::Error>> {
        let filtered_scales = find_scales_with_up_down_intervals()?;
        assert!(
            !filtered_scales.is_empty(),
            "No scales found with separate ascending and descending intervals"
        );
        for scale in &filtered_scales {
            assert!(scale.intervals_ascending.is_some());
            assert_eq!(scale.intervals, None);
            assert_eq!(scale.notes, None);
        }
        Ok(())
    }
}