import argparse
import glob
import gzip
import math
import os
import pickle
import shutil
import subprocess
import sys
from concurrent.futures import ProcessPoolExecutor
# Root of the shared ML data volume. The default --annotations-dir,
# --output-dir and --output paths in main() are all built under it; this is a
# machine-specific path, so pass those flags explicitly on other machines.
DATA_ROOT = "/Volumes/data/Shared/ML"
def load_annotations(gz_path: str) -> list[tuple[float, float]]:
    """Return the ``(start, end)`` time spans from one DALI annotation file.

    The file is a gzip-compressed pickle holding either an object with an
    ``annotations`` attribute or a plain dict with an ``"annotations"`` key; in
    both cases the spans come from ``annotations["annot"]``. The first
    non-empty level in the order ``notes``, ``words``, ``lines``,
    ``paragraphs`` is used.

    Entries whose ``time`` field is missing, malformed or non-numeric, or whose
    end is not after their start, are skipped.

    Args:
        gz_path: Path to a per-song annotation ``.gz`` file.

    Returns:
        Spans in file order (not sorted or merged), or ``[]`` if the file has
        no recognisable annotation structure.

    Raises:
        Whatever ``gzip``/``pickle`` raise for an unreadable file (e.g.
        ``OSError``, ``EOFError``, ``pickle.UnpicklingError``); not caught here.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only run this on DALI files obtained from a trusted source.
    with gzip.open(gz_path, "rb") as f:
        data = pickle.load(f)

    try:
        annots = data.annotations["annot"]
    except (AttributeError, KeyError, TypeError):
        if not isinstance(data, dict):
            return []
        annotations = data.get("annotations")
        # Guard against "annotations" being present but None / not a mapping.
        annots = annotations.get("annot") if isinstance(annotations, dict) else None

    get_level = getattr(annots, "get", None)
    if get_level is None:
        return []

    for level in ("notes", "words", "lines", "paragraphs"):
        entries = get_level(level, [])
        if entries:
            break
    else:
        return []

    segments: list[tuple[float, float]] = []
    for entry in entries:
        try:
            t = entry["time"]
            start = float(t[0])
            end = float(t[1])
        except (KeyError, TypeError, IndexError, ValueError):
            # ValueError: a non-numeric timestamp must skip the entry, not
            # abort the whole file.
            continue
        if end > start:
            segments.append((start, end))
    return segments
def _youtube_url_from_value(value: object) -> str:
    """Turn a YouTube URL or bare video ID into a URL; return "" for empty values."""
    if not value:
        return ""
    text = str(value).strip()
    if text and not text.startswith("http"):
        # Bare video IDs are expanded to a watch URL.
        text = f"https://www.youtube.com/watch?v={text}"
    return text


def _url_map_from_table(table: object) -> dict[str, str]:
    """Build ``{dali_id: url}`` from a header-first table of rows.

    The first row must name ``DALI_ID`` and ``YOUTUBE`` columns (matched
    case-insensitively). Returns ``{}`` if *table* cannot be iterated as rows
    or the header lacks those columns.
    """
    try:
        rows = [list(row) for row in table]
    except TypeError:
        return {}
    if not rows:
        return {}
    header = [str(cell).strip().upper() for cell in rows[0]]
    if "DALI_ID" not in header or "YOUTUBE" not in header:
        return {}
    id_col = header.index("DALI_ID")
    url_col = header.index("YOUTUBE")

    url_map: dict[str, str] = {}
    for row in rows[1:]:
        if len(row) <= max(id_col, url_col):
            continue  # short / malformed row
        url = _youtube_url_from_value(row[url_col])
        if url:
            url_map[str(row[id_col])] = url
    return url_map


def load_dali_info(annotations_dir: str) -> dict[str, str]:
    """Map DALI song IDs to YouTube URLs using ``DALI_DATA_INFO.gz``.

    The info file is looked up directly in *annotations_dir*, then recursively
    beneath it (the first match in sorted order wins). Two pickled layouts are
    accepted:

    * a dict ``{dali_id: entry}``, where *entry* is either a dict carrying the
      URL/ID under the first present key of ``YOUTUBE``, ``youtube``, ``url``,
      ``URL``, or the URL/ID string itself;
    * a table of rows whose first row is a header containing ``DALI_ID`` and
      ``YOUTUBE`` columns, e.g. ``[["DALI_ID", "NAME", "YOUTUBE", "WORKING"], ...]``.
      This is the layout described in the DALI dataset README; verify it
      against your copy of the data.

    Args:
        annotations_dir: Directory to search for ``DALI_DATA_INFO.gz``.

    Returns:
        ``{dali_id: url}``. Bare video IDs are expanded to
        ``https://www.youtube.com/watch?v=<id>``. The map is empty if no info
        file is found or its layout is not recognised.
    """
    info_path = os.path.join(annotations_dir, "DALI_DATA_INFO.gz")
    if not os.path.exists(info_path):
        candidates = sorted(
            glob.glob(os.path.join(annotations_dir, "**", "DALI_DATA_INFO.gz"), recursive=True)
        )
        if not candidates:
            return {}
        info_path = candidates[0]

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only run this on DALI files obtained from a trusted source.
    with gzip.open(info_path, "rb") as f:
        info = pickle.load(f)

    if not isinstance(info, dict):
        return _url_map_from_table(info)

    url_map: dict[str, str] = {}
    for dali_id, entry in info.items():
        if isinstance(entry, dict):
            key = next((k for k in ("YOUTUBE", "youtube", "url", "URL") if k in entry), None)
            value = entry[key] if key is not None else None
        elif isinstance(entry, str):
            value = entry
        else:
            value = None
        url = _youtube_url_from_value(value)
        if url:
            url_map[str(dali_id)] = url
    return url_map
def merge_segments(
    segments: list[tuple[float, float]], gap: float = 0.05,
) -> list[tuple[float, float]]:
    """Merge overlapping or nearly-touching ``(start, end)`` spans.

    Spans are processed in sorted order. A span whose start lies within *gap*
    seconds of the current merged span's end is folded into it.

    Args:
        segments: Spans in any order. The list is NOT modified.
        gap: Maximum gap, in seconds, that is bridged when merging (default
            0.05, the previously hard-coded value).

    Returns:
        A new list of sorted, non-overlapping spans; ``[]`` for empty input.
    """
    merged: list[tuple[float, float]] = []
    # sorted() rather than list.sort(): the caller's list must not be reordered.
    for start, end in sorted(segments):
        if merged and start <= merged[-1][1] + gap:
            prev_start, prev_end = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged
def segments_to_manifest_label(
    vocal_segments: list[tuple[float, float]], duration: float,
) -> str:
    """Encode vocal spans as a comma-separated ``start-end:label`` string.

    The output covers ``[0, duration]``. Each vocal span becomes
    ``"s-e:vocal"`` and each gap longer than 0.01 s becomes
    ``"s-e:non_vocal"``, with times formatted to one decimal place.

    Spans are clipped to ``[0, duration]``, and spans falling entirely outside
    it are dropped. As a result, labels never extend past the end of the audio
    (annotation and audio lengths can disagree) and never contain a negative
    time, which would be ambiguous with the ``-`` separator.

    Args:
        vocal_segments: Sorted, non-overlapping spans, e.g. the output of
            ``merge_segments``.
        duration: Length of the audio in seconds.

    Returns:
        The label string; ``""`` only if *duration* is not positive and there
        are no spans.
    """
    parts: list[str] = []
    prev_end = 0.0
    for start, end in vocal_segments:
        start = max(start, prev_end)  # no negative times / overlap with previous span
        end = min(end, duration)      # never label beyond the audio
        if end <= start:
            continue
        if start > prev_end + 0.01:
            parts.append(f"{prev_end:.1f}-{start:.1f}:non_vocal")
        parts.append(f"{start:.1f}-{end:.1f}:vocal")
        prev_end = end
    if prev_end < duration - 0.01:
        parts.append(f"{prev_end:.1f}-{duration:.1f}:non_vocal")
    return ",".join(parts)
def download_audio(youtube_url: str, output_path: str) -> bool:
    """Download *youtube_url* with yt-dlp as a mono 44.1 kHz 16-bit WAV.

    If *output_path* already exists, nothing is downloaded. Otherwise yt-dlp
    writes to ``<output_path>.tmp.<ext>`` and the finished WAV is moved into
    place. Because of this, *output_path* only appears once a download has
    completed.

    Args:
        youtube_url: Video URL passed to yt-dlp.
        output_path: Final WAV path.

    Returns:
        True if *output_path* exists afterwards. False if the download failed
        for any reason: a yt-dlp error, the 300 s timeout, yt-dlp missing or
        not executable, or no WAV produced. All ``<output_path>.tmp.*``
        leftovers are removed on failure.
    """
    if os.path.exists(output_path):
        return True

    temp_path = output_path + ".tmp"
    temp_wav = f"{temp_path}.wav"
    succeeded = False
    try:
        subprocess.run(
            [
                "yt-dlp",
                "--extract-audio",
                "--audio-format", "wav",
                "--postprocessor-args", "ffmpeg:-ar 44100 -ac 1 -sample_fmt s16",
                "--output", f"{temp_path}.%(ext)s",
                "--no-playlist",
                "--quiet",
                youtube_url,
            ],
            check=True,
            capture_output=True,
            timeout=300,
        )
        if os.path.exists(temp_wav):
            # os.replace also overwrites on Windows, unlike os.rename.
            os.replace(temp_wav, output_path)
            succeeded = True
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers yt-dlp missing/not executable. The failure is
        # reported to the caller through the False return value.
        succeeded = False

    if not succeeded:
        # Remove every temp artefact (.wav/.webm/.m4a/.part/...), not just a
        # fixed list of extensions.
        for leftover in glob.glob(glob.escape(temp_path) + ".*"):
            try:
                os.remove(leftover)
            except OSError:
                pass  # best-effort cleanup
    return succeeded
def get_wav_duration(wav_path: str) -> float:
    """Return the duration of *wav_path* in seconds, as reported by ffprobe.

    Args:
        wav_path: Audio file to probe.

    Returns:
        The duration, or 0.0 if it cannot be determined. That covers ffprobe
        being missing or not executable, the 30 s timeout, and output that is
        not a finite number (e.g. ``N/A``, or empty output from a failed probe).
    """
    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "quiet",
                "-show_entries", "format=duration",
                "-of", "csv=p=0",
                wav_path,
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )
        duration = float(result.stdout.strip())
    except (OSError, subprocess.TimeoutExpired, ValueError):
        # No check=True, so a failing ffprobe shows up as empty stdout
        # (ValueError), not as CalledProcessError.
        return 0.0
    # NaN compares false with everything, so it would pass a "<= 0" validity
    # check; normalise non-finite values to the 0.0 failure value.
    return duration if math.isfinite(duration) else 0.0
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the manifest builder."""
    parser = argparse.ArgumentParser(
        description="Prepare DALI dataset manifest for vocal detector training"
    )
    parser.add_argument(
        "--annotations-dir",
        default=os.path.join(DATA_ROOT, "dali", "DALI_v1.0"),
        help="Path to extracted DALI annotation .gz files",
    )
    parser.add_argument(
        "--output-dir",
        default=os.path.join(DATA_ROOT, "dali", "wavs"),
        help="Directory to store downloaded WAV files",
    )
    parser.add_argument(
        "--output",
        default=os.path.join(DATA_ROOT, "dali_manifest.tsv"),
        help="Output TSV manifest path",
    )
    parser.add_argument(
        "--max-songs",
        type=int,
        default=0,
        help="Limit number of songs to process (0 = all)",
    )
    parser.add_argument(
        "--skip-download",
        action="store_true",
        help="Skip downloading, only process already-downloaded WAVs",
    )
    return parser.parse_args()


def _find_annotation_files(annotations_dir: str) -> list[str]:
    """Return sorted per-song annotation ``.gz`` paths under *annotations_dir*.

    The directory itself is searched first; a recursive search is used only if
    that finds nothing. Files whose name contains ``DALI_DATA_INFO`` are
    excluded.
    """
    def collect(pattern: str, recursive: bool) -> list[str]:
        return [
            path
            for path in sorted(glob.glob(pattern, recursive=recursive))
            if "DALI_DATA_INFO" not in os.path.basename(path)
        ]

    found = collect(os.path.join(annotations_dir, "*.gz"), recursive=False)
    if not found:
        found = collect(os.path.join(annotations_dir, "**", "*.gz"), recursive=True)
    return found


def _write_manifest(path: str, entries: list[tuple[str, str, str]]) -> None:
    """Write *entries* as ``wav_path<TAB>label_type<TAB>label`` lines.

    The parent directory of *path* is created if it does not exist.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{wav}\t{kind}\t{label}\n" for wav, kind, label in entries)


def main() -> None:
    """Build a TSV manifest of vocal / non-vocal segment labels for DALI songs.

    For each annotation file, the script:

    * extracts the vocal spans;
    * ensures the matching ``<dali_id>.wav`` exists in ``--output-dir``,
      downloading it with yt-dlp unless ``--skip-download`` is set;
    * measures the WAV with ffprobe;
    * records a ``segments`` label.

    Songs that cannot be processed are skipped and counted in the final
    summary. Exits with status 1 if a required tool or the annotation files
    are missing.
    """
    args = _parse_args()

    if not args.skip_download and shutil.which("yt-dlp") is None:
        print("ERROR: yt-dlp not found in PATH. Install with: pip install yt-dlp", file=sys.stderr)
        sys.exit(1)
    # ffprobe measures every WAV, including pre-downloaded ones, so it is
    # required even with --skip-download.
    if shutil.which("ffprobe") is None:
        print("ERROR: ffprobe not found in PATH (comes with ffmpeg).", file=sys.stderr)
        sys.exit(1)

    gz_files = _find_annotation_files(args.annotations_dir)
    print(f"Found {len(gz_files)} annotation files in {args.annotations_dir}")
    if not gz_files:
        print("ERROR: No annotation .gz files found.", file=sys.stderr)
        sys.exit(1)

    url_map = load_dali_info(args.annotations_dir)
    print(f"Loaded {len(url_map)} YouTube URLs from DALI_DATA_INFO.gz")
    if not url_map and not args.skip_download:
        print("WARNING: No YouTube URLs found. Will only process existing WAVs.")
        args.skip_download = True

    if args.max_songs > 0:
        gz_files = gz_files[:args.max_songs]
        print(f"Limiting to {args.max_songs} songs")

    os.makedirs(args.output_dir, exist_ok=True)
    manifest_entries: list[tuple[str, str, str]] = []
    processed = 0
    skipped_no_url = 0
    skipped_download = 0
    skipped_no_segments = 0
    skipped_bad_annotation = 0

    print(f"\nProcessing {len(gz_files)} songs...")
    for i, gz_path in enumerate(gz_files, start=1):
        # "<id>.gz" -> "<id>"; a doubled ".gz.gz" suffix is also stripped.
        dali_id = os.path.splitext(os.path.basename(gz_path))[0].removesuffix(".gz")
        print(f" [{i}/{len(gz_files)}] {dali_id}...", end=" ", flush=True)

        try:
            segments = load_annotations(gz_path)
        except Exception as exc:  # per-file boundary: one corrupt file must not abort the run
            print(f"SKIP (unreadable annotations: {exc})")
            skipped_bad_annotation += 1
            continue
        if not segments:
            print("SKIP (no vocal segments)")
            skipped_no_segments += 1
            continue

        wav_path = os.path.join(args.output_dir, f"{dali_id}.wav")
        if not os.path.exists(wav_path):
            if args.skip_download:
                print("SKIP (no WAV)")
                skipped_download += 1
                continue
            yt_url = url_map.get(dali_id)
            if not yt_url:
                print("SKIP (no YouTube URL)")
                skipped_no_url += 1
                continue
            if not download_audio(yt_url, wav_path):
                print("FAILED (download)")
                skipped_download += 1
                continue

        duration = get_wav_duration(wav_path)
        if duration <= 0:
            print("SKIP (bad WAV)")
            skipped_download += 1
            continue

        merged = merge_segments(segments)
        segment_label = segments_to_manifest_label(merged, duration)
        manifest_entries.append((wav_path, "segments", segment_label))
        processed += 1
        vocal_secs = sum(e - s for s, e in merged)
        print(f"OK ({len(merged)} vocal regions, {vocal_secs:.0f}s / {duration:.0f}s)")

    _write_manifest(args.output, manifest_entries)

    print(f"\nManifest written to: {args.output}")
    print(f" Songs processed: {processed}")
    print(f" Skipped (no URL): {skipped_no_url}")
    print(f" Skipped (download): {skipped_download}")
    print(f" Skipped (no vocal): {skipped_no_segments}")
    print(f" Skipped (bad annot): {skipped_bad_annotation}")
    print(f" Manifest entries: {len(manifest_entries)}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()