av1an_core/
vapoursynth.rs

1use std::{
2    fmt::Display,
3    fs::{create_dir_all, File},
4    io::Write,
5    path::{absolute, Path, PathBuf},
6    process::Command,
7};
8
9use anyhow::{anyhow, bail, Context};
10use av_format::rational::Rational64;
11use path_abs::{PathAbs, PathInfo};
12use serde::{Deserialize, Serialize};
13use strum::{EnumString, IntoStaticStr};
14use tracing::info;
15use vapoursynth::{
16    core::CoreRef,
17    prelude::*,
18    video_info::{Resolution, VideoInfo},
19};
20
21use super::ChunkMethod;
22use crate::{
23    ffmpeg::FFPixelFormat,
24    metrics::{
25        butteraugli::ButteraugliSubMetric,
26        xpsnr::{weight_xpsnr, XPSNRSubMetric},
27    },
28    ClipInfo,
29    Input,
30    InputPixelFormat,
31};
32
33#[derive(
34    Serialize, PartialEq, Debug, Clone, Copy, EnumString, IntoStaticStr, Hash, Eq, Deserialize,
35)]
36pub enum CacheSource {
37    #[strum(serialize = "source")]
38    SOURCE,
39    #[strum(serialize = "temp")]
40    TEMP,
41}
42
43impl Display for CacheSource {
44    #[inline]
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.write_str(<&'static str>::from(self))
47    }
48}
49
50/// Contains a list of installed Vapoursynth plugins which may be used by av1an
51#[derive(Debug, Clone, Copy)]
52pub struct VapoursynthPlugins {
53    pub lsmash:     bool,
54    pub ffms2:      bool,
55    pub dgdecnv:    bool,
56    pub bestsource: bool,
57    pub julek:      bool,
58    pub vszip:      VSZipVersion,
59    pub vship:      bool,
60}
61
62impl VapoursynthPlugins {
63    #[inline]
64    pub fn best_available_chunk_method(&self) -> ChunkMethod {
65        if self.bestsource {
66            ChunkMethod::BESTSOURCE
67        } else if self.lsmash {
68            ChunkMethod::LSMASH
69        } else if self.ffms2 {
70            ChunkMethod::FFMS2
71        } else if self.dgdecnv {
72            ChunkMethod::DGDECNV
73        } else {
74            ChunkMethod::Hybrid
75        }
76    }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq)]
80pub enum VSZipVersion {
81    /// R7 or newer, has XPSNR and API changes
82    New,
83    /// prior to R7
84    Legacy,
85    /// not installed
86    None,
87}
88
89#[inline]
90pub fn get_vapoursynth_plugins() -> anyhow::Result<VapoursynthPlugins> {
91    let env = Environment::new().expect("Failed to initialize VapourSynth environment");
92    let core = env.get_core().expect("Failed to get VapourSynth core");
93
94    Ok(VapoursynthPlugins {
95        lsmash:     core.get_plugin_by_id(PluginId::Lsmash.as_str())?.is_some(),
96        ffms2:      core.get_plugin_by_id(PluginId::Ffms2.as_str())?.is_some(),
97        dgdecnv:    core.get_plugin_by_id(PluginId::DGDecNV.as_str())?.is_some(),
98        bestsource: core.get_plugin_by_id(PluginId::BestSource.as_str())?.is_some(),
99        julek:      core.get_plugin_by_id(PluginId::Julek.as_str())?.is_some(),
100        vszip:      if let Some(plugin) = core.get_plugin_by_id(PluginId::Vszip.as_str())? {
101            if is_vszip_r7_or_newer(plugin)? {
102                VSZipVersion::New
103            } else {
104                VSZipVersion::Legacy
105            }
106        } else {
107            VSZipVersion::None
108        },
109        vship:      core.get_plugin_by_id(PluginId::Vship.as_str())?.is_some(),
110    })
111}
112
113// There is no way to get the version of a plugin
114// so check for a function signature instead
115fn is_vszip_r7_or_newer(plugin: Plugin) -> anyhow::Result<bool> {
116    // R7 adds XPSNR and also introduces breaking changes the API
117    Ok(plugin.get_plugin_function_by_name("XPSNR")?.is_some())
118}
119
120#[inline]
121pub fn get_clip_info(source: &Input, vspipe_args_map: &OwnedMap) -> anyhow::Result<ClipInfo> {
122    const CONTEXT_MSG: &str = "get_clip_info";
123    const OUTPUT_INDEX: i32 = 0;
124
125    let mut environment = Environment::new().context(CONTEXT_MSG)?;
126    if environment.set_variables(vspipe_args_map).is_err() {
127        bail!("Failed to set vspipe arguments");
128    };
129    if source.is_vapoursynth() {
130        environment
131            .eval_file(source.as_path(), EvalFlags::SetWorkingDir)
132            .context(CONTEXT_MSG)?;
133    } else {
134        environment
135            .eval_script(&source.as_script_text(None, None, None)?)
136            .context(CONTEXT_MSG)?;
137    }
138
139    let (node, _) = environment.get_output(OUTPUT_INDEX)?;
140    let info = node.info();
141
142    Ok(ClipInfo {
143        num_frames:               get_num_frames(&info)?,
144        format_info:              InputPixelFormat::VapourSynth {
145            bit_depth: get_bit_depth(&info)?,
146        },
147        frame_rate:               get_frame_rate(&info)?,
148        resolution:               get_resolution(&info)?,
149        transfer_characteristics: match get_transfer(&environment)? {
150            16 => av1_grain::TransferFunction::SMPTE2084,
151            _ => av1_grain::TransferFunction::BT1886,
152        },
153    })
154}
155
156/// Get the number of frames from an environment that has already been
157/// evaluated on a script.
158fn get_num_frames(info: &VideoInfo) -> anyhow::Result<usize> {
159    let num_frames = {
160        if Property::Variable == info.resolution {
161            bail!("Cannot output clips with varying dimensions");
162        }
163        if Property::Variable == info.framerate {
164            bail!("Cannot output clips with varying framerate");
165        }
166
167        info.num_frames
168    };
169
170    assert!(num_frames != 0, "vapoursynth reported 0 frames");
171
172    Ok(num_frames)
173}
174
175fn get_frame_rate(info: &VideoInfo) -> anyhow::Result<Rational64> {
176    match info.framerate {
177        Property::Variable => bail!("Cannot output clips with varying framerate"),
178        Property::Constant(fps) => Ok(Rational64::new(
179            fps.numerator as i64,
180            fps.denominator as i64,
181        )),
182    }
183}
184
185/// Get the bit depth from an environment that has already been
186/// evaluated on a script.
187fn get_bit_depth(info: &VideoInfo) -> anyhow::Result<usize> {
188    let bits_per_sample = info.format.bits_per_sample();
189
190    Ok(bits_per_sample as usize)
191}
192
193/// Get the resolution from an environment that has already been
194/// evaluated on a script.
195fn get_resolution(info: &VideoInfo) -> anyhow::Result<(u32, u32)> {
196    let resolution = {
197        match info.resolution {
198            Property::Variable => {
199                bail!("Cannot output clips with variable resolution");
200            },
201            Property::Constant(x) => x,
202        }
203    };
204
205    Ok((resolution.width as u32, resolution.height as u32))
206}
207
208/// Get the transfer characteristics from an environment that has already
209/// been evaluated on a script.
210fn get_transfer(env: &Environment) -> anyhow::Result<u8> {
211    // Get the output node.
212    const OUTPUT_INDEX: i32 = 0;
213
214    let (node, _) = env.get_output(OUTPUT_INDEX)?;
215    let frame = node.get_frame(0).context("get_transfer")?;
216    let transfer = frame.props().get::<i64>("_Transfer").map(|val| val as u8).unwrap_or(2);
217
218    Ok(transfer)
219}
220
/// VapourSynth plugins that av1an looks up at runtime.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum PluginId {
    Std,
    Resize,
    Lsmash,
    Ffms2,
    BestSource,
    DGDecNV,
    Julek,
    Vszip,
    Vship,
}

impl PluginId {
    /// The identifier passed to `get_plugin_by_id` to look this plugin up.
    const fn as_str(self) -> &'static str {
        match self {
            Self::Std => "com.vapoursynth.std",
            Self::Resize => "com.vapoursynth.resize",
            Self::Lsmash => "systems.innocent.lsmas",
            Self::Ffms2 => "com.vapoursynth.ffms2",
            Self::BestSource => "com.vapoursynth.bestsource",
            Self::DGDecNV => "com.vapoursynth.dgdecodenv",
            Self::Julek => "com.julek.plugin",
            Self::Vszip => "com.julek.vszip",
            Self::Vship => "com.lumen.vship",
        }
    }
}
249
250fn get_plugin(core: CoreRef, plugin_id: PluginId) -> anyhow::Result<Plugin> {
251    let plugin = core.get_plugin_by_id(plugin_id.as_str())?;
252
253    plugin.ok_or_else(|| {
254        anyhow::anyhow!(
255            "Failed to get VapourSynth {plugin_id} plugin",
256            plugin_id = plugin_id.as_str()
257        )
258    })
259}
260
261fn import_lsmash<'core>(
262    core: CoreRef<'core>,
263    encoded: &Path,
264    cache: Option<bool>,
265) -> anyhow::Result<Node<'core>> {
266    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
267    let lsmash = get_plugin(core, PluginId::Lsmash)?;
268    let absolute_encoded_path = absolute(encoded)?;
269
270    let mut arguments = vapoursynth::map::OwnedMap::new(api);
271    arguments.set(
272        "source",
273        &absolute_encoded_path.as_os_str().as_encoded_bytes(),
274    )?;
275    // Enable cache by default.
276    if let Some(cache) = cache {
277        arguments.set_int("cache", if cache { 1 } else { 0 })?;
278    }
279    // Allow hardware acceleration, falls back to software decoding.
280    arguments.set_int("prefer_hw", 3)?;
281
282    let error_message = format!(
283        "Failed to import {video_path} with lsmash",
284        video_path = encoded.display()
285    );
286
287    lsmash
288        .invoke("LWLibavSource", &arguments)
289        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
290        .get_video_node("clip")
291        .map_err(|_| anyhow::anyhow!(error_message.clone()))
292}
293
294fn import_ffms2<'core>(
295    core: CoreRef<'core>,
296    encoded: &Path,
297    cache: Option<bool>,
298) -> anyhow::Result<Node<'core>> {
299    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
300    let ffms2 = get_plugin(core, PluginId::Ffms2)?;
301    let absolute_encoded_path = absolute(encoded)?;
302
303    let mut arguments = vapoursynth::map::OwnedMap::new(api);
304    arguments.set(
305        "source",
306        &absolute_encoded_path.as_os_str().as_encoded_bytes(),
307    )?;
308
309    // Enable cache by default.
310    if let Some(cache) = cache {
311        arguments.set_int("cache", if cache { 1 } else { 0 })?;
312    }
313
314    let error_message = format!(
315        "Failed to import {video_path} with ffms2",
316        video_path = encoded.display()
317    );
318
319    ffms2
320        .invoke("Source", &arguments)
321        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
322        .get_video_node("clip")
323        .map_err(|_| anyhow::anyhow!(error_message.clone()))
324}
325
326fn import_bestsource<'core>(
327    core: CoreRef<'core>,
328    encoded: &Path,
329    cache: Option<bool>,
330) -> anyhow::Result<Node<'core>> {
331    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
332    let bestsource = get_plugin(core, PluginId::BestSource)?;
333    let absolute_encoded_path = absolute(encoded)?;
334
335    let mut arguments = vapoursynth::map::OwnedMap::new(api);
336    arguments.set(
337        "source",
338        &absolute_encoded_path.as_os_str().as_encoded_bytes(),
339    )?;
340
341    // Enable cache by default.
342    // Always try to read index but only write index to disk when it will make a
343    // noticeable difference on subsequent runs and store index files in the
344    // absolute path in *cachepath* with track number and index extension
345    // appended
346    if let Some(cache) = cache {
347        arguments.set_int("cachemode", if cache { 3 } else { 0 })?;
348    }
349
350    let error_message = format!(
351        "Failed to import {video_path} with bestsource",
352        video_path = encoded.display()
353    );
354
355    bestsource
356        .invoke("VideoSource", &arguments)
357        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
358        .get_video_node("clip")
359        .map_err(|_| anyhow::anyhow!(error_message.clone()))
360}
361
362// Attempts to import video using FFMS2, BestSource, or LSMASH in that order
363fn import_video<'core>(
364    core: CoreRef<'core>,
365    encoded: &Path,
366    cache: Option<bool>,
367) -> anyhow::Result<Node<'core>> {
368    import_ffms2(core, encoded, cache)
369        .or_else(|_| {
370            import_bestsource(core, encoded, cache).or_else(|_| import_lsmash(core, encoded, cache))
371        })
372        .map_err(|_| {
373            anyhow::anyhow!(
374                "Failed to import {video_path} with any decoder",
375                video_path = encoded.display()
376            )
377        })
378}
379
380fn trim_node<'core>(
381    core: CoreRef<'core>,
382    node: &Node<'core>,
383    start: u32,
384    end: u32,
385) -> anyhow::Result<Node<'core>> {
386    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
387    let std = get_plugin(core, PluginId::Std)?;
388
389    let mut arguments = vapoursynth::map::OwnedMap::new(api);
390    arguments.set("clip", node)?;
391    arguments.set("first", &(start as i64))?;
392    arguments.set("last", &(end as i64))?;
393
394    let error_message = format!("Failed to trim video from {start} to {end}");
395
396    std.invoke("Trim", &arguments)
397        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
398        .get_video_node("clip")
399        .map_err(|_| anyhow::anyhow!(error_message.clone()))
400}
401
402#[inline]
403pub fn resize_node<'core>(
404    core: CoreRef<'core>,
405    node: &Node<'core>,
406    width: Option<u32>,
407    height: Option<u32>,
408    format: Option<PresetFormat>,
409    matrix_in_s: Option<&'static str>,
410) -> anyhow::Result<Node<'core>> {
411    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
412    let std = get_plugin(core, PluginId::Resize)?;
413
414    let mut arguments = vapoursynth::map::OwnedMap::new(api);
415    arguments.set("clip", node)?;
416    if let Some(width) = width {
417        arguments.set_int("width", width as i64)?;
418    }
419    if let Some(height) = height {
420        arguments.set_int("height", height as i64)?;
421    }
422    if let Some(format) = format {
423        arguments.set_int("format", format as i64)?;
424    }
425    if let Some(matrix_in_s) = matrix_in_s {
426        arguments.set("matrix_in_s", &matrix_in_s.as_bytes())?;
427    }
428
429    let error_message = format!(
430        "Failed to resize video to {width}x{height}",
431        width = width.unwrap_or(0),
432        height = height.unwrap_or(0)
433    );
434
435    std.invoke("Bicubic", &arguments)
436        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
437        .get_video_node("clip")
438        .map_err(|_| anyhow::anyhow!(error_message.clone()))
439}
440
441fn select_every<'core>(
442    core: CoreRef<'core>,
443    node: &Node<'core>,
444    n: usize,
445) -> anyhow::Result<Node<'core>> {
446    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
447    let std = get_plugin(core, PluginId::Std)?;
448
449    let mut arguments = vapoursynth::map::OwnedMap::new(api);
450    arguments.set("clip", node)?;
451    arguments.set_int("cycle", n as i64)?;
452    arguments.set_int_array("offsets", &[0])?;
453
454    let error_message = format!("Failed to select 1 of every {n} frames");
455
456    std.invoke("SelectEvery", &arguments)
457        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
458        .get_video_node("clip")
459        .map_err(|_| anyhow::anyhow!(error_message.clone()))
460}
461
462fn compare_ssimulacra2<'core>(
463    core: CoreRef<'core>,
464    source: &Node<'core>,
465    encoded: &Node<'core>,
466    plugins: VapoursynthPlugins,
467) -> anyhow::Result<(Node<'core>, &'static str)> {
468    if !plugins.vship && plugins.vszip == VSZipVersion::None {
469        return Err(anyhow::anyhow!("SSIMULACRA2 not available"));
470    }
471
472    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
473    let plugin = get_plugin(
474        core,
475        if plugins.vship {
476            PluginId::Vship
477        } else {
478            PluginId::Vszip
479        },
480    )?;
481
482    let error_message = format!(
483        "Failed to calculate SSIMULACRA2 with {plugin_id} plugin",
484        plugin_id = if plugins.vship {
485            PluginId::Vship.as_str()
486        } else {
487            PluginId::Vszip.as_str()
488        }
489    );
490
491    let mut arguments = vapoursynth::map::OwnedMap::new(api);
492    arguments.set("reference", source)?;
493    arguments.set("distorted", encoded)?;
494
495    if plugins.vship {
496        arguments.set_int("numStream", 4)?;
497    } else if plugins.vszip == VSZipVersion::Legacy {
498        // Handle older vszip API
499        arguments.set_int("mode", 0)?;
500    }
501
502    let output = plugin
503        .invoke(
504            if plugins.vship || plugins.vszip == VSZipVersion::New {
505                "SSIMULACRA2"
506            } else {
507                // Handle older vszip API
508                "Metrics"
509            },
510            &arguments,
511        )
512        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
513        .get_video_node("clip")
514        .map_err(|_| anyhow::anyhow!(error_message.clone()))?;
515
516    Ok((
517        output,
518        if plugins.vship || plugins.vszip == VSZipVersion::Legacy {
519            "_SSIMULACRA2"
520        } else {
521            // Handle newer vszip API
522            "SSIMULACRA2"
523        },
524    ))
525}
526
527fn compare_butteraugli<'core>(
528    core: CoreRef<'core>,
529    source: &Node<'core>,
530    encoded: &Node<'core>,
531    submetric: ButteraugliSubMetric,
532    plugins: VapoursynthPlugins,
533) -> anyhow::Result<(Node<'core>, &'static str)> {
534    if !plugins.vship && !plugins.julek {
535        return Err(anyhow::anyhow!("butteraugli not available"));
536    }
537
538    const INTENSITY: f64 = 203.0;
539    let error_message = format!(
540        "Failed to calculate butteraugli with {plugin_id} plugin",
541        plugin_id = if plugins.vship {
542            PluginId::Vship.as_str()
543        } else {
544            PluginId::Julek.as_str()
545        }
546    );
547
548    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
549    let plugin = get_plugin(
550        core,
551        if plugins.vship {
552            PluginId::Vship
553        } else {
554            PluginId::Julek
555        },
556    )?;
557
558    let mut arguments = vapoursynth::map::OwnedMap::new(api);
559    arguments.set_int("distmap", 1)?;
560
561    if plugins.vship {
562        arguments.set("reference", source)?;
563        arguments.set("distorted", encoded)?;
564        arguments.set_float("intensity_multiplier", INTENSITY)?;
565        arguments.set_int("numStream", 4)?;
566    } else if plugins.julek {
567        // Inputs must be in RGBS format
568        let formatted_source = resize_node(
569            core,
570            source,
571            None,
572            None,
573            Some(PresetFormat::RGBS),
574            Some("709"),
575        )?;
576        let formatted_encoded = resize_node(
577            core,
578            encoded,
579            None,
580            None,
581            Some(PresetFormat::RGBS),
582            Some("709"),
583        )?;
584
585        arguments.set("reference", &formatted_source)?;
586        arguments.set("distorted", &formatted_encoded)?;
587        arguments.set_float("intensity_target", INTENSITY)?;
588    }
589
590    let output = plugin
591        .invoke(
592            if plugins.vship {
593                "BUTTERAUGLI"
594            } else {
595                "butteraugli"
596            },
597            &arguments,
598        )
599        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
600        .get_video_node("clip")
601        .map_err(|_| anyhow::anyhow!(error_message.clone()))?;
602
603    Ok((
604        output,
605        if plugins.vship {
606            if submetric == ButteraugliSubMetric::InfiniteNorm {
607                "_BUTTERAUGLI_INFNorm"
608            } else {
609                "_BUTTERAUGLI_3Norm"
610            }
611        } else {
612            "_FrameButteraugli"
613        },
614    ))
615}
616
617fn compare_xpsnr<'core>(
618    core: CoreRef<'core>,
619    source: &Node<'core>,
620    encoded: &Node<'core>,
621    plugins: VapoursynthPlugins,
622) -> anyhow::Result<Node<'core>> {
623    let api = API::get().ok_or_else(|| anyhow::anyhow!("Failed to get VapourSynth API"))?;
624
625    if plugins.vszip != VSZipVersion::New {
626        return Err(anyhow::anyhow!("XPSNR not available"));
627    }
628
629    let plugin = get_plugin(core, PluginId::Vszip)?;
630
631    // XPSNR requires YUV input and a maximum bit depth of 10
632    let formatted_source = resize_node(
633        core,
634        source,
635        None,
636        None,
637        Some(PresetFormat::YUV444P10),
638        None,
639    )?;
640    let formatted_encoded = resize_node(
641        core,
642        encoded,
643        None,
644        None,
645        Some(PresetFormat::YUV444P10),
646        None,
647    )?;
648
649    let mut arguments = vapoursynth::map::OwnedMap::new(api);
650    arguments.set("reference", &formatted_source)?;
651    arguments.set("distorted", &formatted_encoded)?;
652
653    let error_message = format!(
654        "Failed to calculate XPSNR with {plugin_id} plugin",
655        plugin_id = PluginId::Vszip.as_str()
656    );
657
658    plugin
659        .invoke("XPSNR", &arguments)
660        .map_err(|_| anyhow::anyhow!(error_message.clone()))?
661        .get_video_node("clip")
662        .map_err(|_| anyhow::anyhow!(error_message.clone()))
663}
664
665#[inline]
666pub fn create_vs_file(loadscript_args: &LoadscriptArgs) -> anyhow::Result<(PathBuf, bool)> {
667    let (load_script_text, cache_file_already_exists) =
668        generate_loadscript_text(&LoadscriptArgs {
669            temp:                             loadscript_args.temp,
670            source:                           loadscript_args.source,
671            chunk_method:                     loadscript_args.chunk_method,
672            scene_detection_downscale_height: loadscript_args.scene_detection_downscale_height,
673            scene_detection_pixel_format:     loadscript_args.scene_detection_pixel_format,
674            scene_detection_scaler:           loadscript_args.scene_detection_scaler,
675            is_proxy:                         loadscript_args.is_proxy,
676            cache_mode:                       loadscript_args.cache_mode,
677        })?;
678    // Ensure the temp folder exists
679    let temp: &Path = loadscript_args.temp.as_ref();
680    let split_folder = temp.join("split");
681    create_dir_all(&split_folder)?;
682
683    if loadscript_args.chunk_method == ChunkMethod::DGDECNV {
684        let absolute_source = absolute(loadscript_args.source)?;
685        let dgindexnv_output = split_folder.join(if loadscript_args.is_proxy {
686            "index_proxy.dgi"
687        } else {
688            "index.dgi"
689        });
690
691        if !dgindexnv_output.exists() {
692            info!("Indexing input with DGDecNV");
693
694            // Run dgindexnv to generate the .dgi index file
695            Command::new("dgindexnv")
696                .arg("-h")
697                .arg("-i")
698                .arg(&absolute_source)
699                .arg("-o")
700                .arg(&dgindexnv_output)
701                .output()?;
702        }
703    }
704
705    let load_script_path = split_folder.join(if loadscript_args.is_proxy {
706        "loadscript_proxy.vpy"
707    } else {
708        "loadscript.vpy"
709    });
710    let mut load_script = File::create(&load_script_path)?;
711
712    load_script.write_all(load_script_text.as_bytes())?;
713
714    Ok((load_script_path, cache_file_already_exists))
715}
716
717pub struct LoadscriptArgs<'a> {
718    pub temp:                             &'a str,
719    pub source:                           &'a Path,
720    pub chunk_method:                     ChunkMethod,
721    pub scene_detection_downscale_height: Option<usize>,
722    pub scene_detection_pixel_format:     Option<FFPixelFormat>,
723    pub scene_detection_scaler:           &'a str,
724    pub is_proxy:                         bool,
725    pub cache_mode:                       CacheSource,
726}
727
728#[inline]
729pub fn generate_loadscript_text(
730    loadscript_args: &LoadscriptArgs,
731) -> anyhow::Result<(String, bool)> {
732    let temp: &Path = loadscript_args.temp.as_ref();
733    let source = absolute(loadscript_args.source)?;
734
735    let cache_file = PathAbs::new(temp.join("split").join(format!(
736        "{}cache.{}",
737        if loadscript_args.is_proxy {
738            "proxy_"
739        } else {
740            ""
741        },
742        match loadscript_args.chunk_method {
743            ChunkMethod::FFMS2 => "ffindex",
744            ChunkMethod::LSMASH => "lwi",
745            ChunkMethod::DGDECNV => "dgi",
746            ChunkMethod::BESTSOURCE => "bsindex",
747            _ => return Err(anyhow!("invalid chunk method")),
748        }
749    )))?;
750    let chunk_method_lower = match loadscript_args.chunk_method {
751        ChunkMethod::FFMS2 => "ffms2",
752        ChunkMethod::LSMASH => "lsmash",
753        ChunkMethod::DGDECNV => "dgdecnv",
754        ChunkMethod::BESTSOURCE => "bestsource",
755        _ => return Err(anyhow!("invalid chunk method")),
756    };
757
758    // Only used for DGDECNV
759    let dgindex_path = match loadscript_args.chunk_method {
760        ChunkMethod::DGDECNV => {
761            let dgindexnv_output = temp.join("split").join(if loadscript_args.is_proxy {
762                "index_proxy.dgi"
763            } else {
764                "index.dgi"
765            });
766            &absolute(&dgindexnv_output)?
767        },
768        _ => &source,
769    };
770
771    // Include rich loadscript.vpy and specify source, chunk_method, and cache_file
772    // Also specify downscale_height, pixel_format, and scaler for Scene Detection
773    // TODO should probably check if the syntax for rust strings and escaping utf
774    // and stuff like that is the same as in python
775    let mut load_script_text = include_str!("loadscript.vpy")
776        .replace(
777            "source = os.environ.get(\"AV1AN_SOURCE\", None)",
778            &format!("source = r\"{}\"", match loadscript_args.chunk_method {
779                ChunkMethod::DGDECNV => dgindex_path.display(),
780                _ => source.display(),
781            }),
782        )
783        .replace(
784            "chunk_method = os.environ.get(\"AV1AN_CHUNK_METHOD\", None)",
785            &format!("chunk_method = {chunk_method_lower:?}"),
786        );
787
788    if let Some(scene_detection_downscale_height) = loadscript_args.scene_detection_downscale_height
789    {
790        load_script_text = load_script_text.replace(
791            "downscale_height = os.environ.get(\"AV1AN_DOWNSCALE_HEIGHT\", None)",
792            &format!(
793                "downscale_height = os.environ.get(\"AV1AN_DOWNSCALE_HEIGHT\", \
794                 {scene_detection_downscale_height})"
795            ),
796        );
797    }
798    if let Some(scene_detection_pixel_format) = loadscript_args.scene_detection_pixel_format {
799        load_script_text = load_script_text.replace(
800            "sc_pix_format = os.environ.get(\"AV1AN_PIXEL_FORMAT\", None)",
801            &format!(
802                "pixel_format = os.environ.get(\"AV1AN_PIXEL_FORMAT\", \
803                 \"{scene_detection_pixel_format:?}\")"
804            ),
805        );
806    }
807    if loadscript_args.cache_mode == CacheSource::TEMP {
808        load_script_text = load_script_text.replace(
809            "cache_file = os.environ.get(\"AV1AN_CACHE_FILE\", None)",
810            &format!(
811                "cache_file = r\"{}\"",
812                dunce::simplified(cache_file.as_path()).display(),
813            ),
814        );
815    }
816
817    load_script_text = load_script_text.replace(
818        "cache_mode = os.environ.get(\"AV1AN_CACHE_MODE\", None)",
819        &format!("cache_mode = \"{}\"", loadscript_args.cache_mode),
820    );
821
822    let scene_detection_scaler = loadscript_args.scene_detection_scaler;
823    load_script_text = load_script_text.replace(
824        "scaler = os.environ.get(\"AV1AN_SCALER\", None)",
825        &format!("scaler = os.environ.get(\"AV1AN_SCALER\", {scene_detection_scaler:?})"),
826    );
827
828    let cache_file_already_exists = match loadscript_args.chunk_method {
829        ChunkMethod::DGDECNV => dgindex_path.exists(),
830        _ => cache_file.exists(),
831    };
832
833    Ok((load_script_text, cache_file_already_exists))
834}
835
836#[inline]
837pub fn get_source_chunk<'core>(
838    core: CoreRef<'core>,
839    source_node: &Node<'core>,
840    frame_range: (u32, u32),
841    probe_res: Option<(u32, u32)>,
842    sample_rate: usize,
843) -> anyhow::Result<Node<'core>> {
844    let mut chunk_node = trim_node(core, source_node, frame_range.0, frame_range.1 - 1)?;
845
846    if let Some((width, height)) = probe_res {
847        chunk_node = resize_node(core, &chunk_node, Some(width), Some(height), None, None)?;
848    }
849
850    if sample_rate > 1 {
851        chunk_node = select_every(core, &chunk_node, sample_rate)?;
852    }
853
854    Ok(chunk_node)
855}
856
857#[inline]
858pub fn get_comparands<'core>(
859    core: CoreRef<'core>,
860    source_node: &Node<'core>,
861    encoded: &Path,
862    frame_range: (u32, u32),
863    probe_res: Option<(u32, u32)>,
864    sample_rate: usize,
865) -> anyhow::Result<(Node<'core>, Node<'core>)> {
866    let chunk_node = get_source_chunk(core, source_node, frame_range, probe_res, sample_rate)?;
867    let encoded_node = import_video(core, encoded, Some(false))?;
868    let resized_encoded_node = if let Some((width, height)) = probe_res {
869        resize_node(core, &encoded_node, Some(width), Some(height), None, None)?
870    } else {
871        let chunk_node_resolution = chunk_node.info().resolution;
872        let (width, height) = match chunk_node_resolution {
873            Property::Variable => (0, 0),
874            Property::Constant(Resolution {
875                width,
876                height,
877            }) => (width as u32, height as u32),
878        };
879        resize_node(core, &encoded_node, Some(width), Some(height), None, None)?
880    };
881
882    Ok((chunk_node, resized_encoded_node))
883}
884
885#[inline]
886pub fn measure_butteraugli(
887    submetric: ButteraugliSubMetric,
888    source: &Input,
889    encoded: &Path,
890    frame_range: (u32, u32),
891    probe_res: Option<(u32, u32)>,
892    sample_rate: usize,
893    plugins: VapoursynthPlugins,
894) -> anyhow::Result<Vec<f64>> {
895    let mut environment = Environment::new()?;
896    let args = source.as_vspipe_args_map()?;
897    environment.set_variables(&args)?;
898    // Cannot use eval_file because it causes file system access errors during
899    // Target Quality probing
900    // Consider using eval_file only when source is not in CWD
901    environment.eval_script(&source.as_script_text(None, None, None)?)?;
902    let core = environment.get_core()?;
903
904    let source_node = environment.get_output(0)?.0;
905    let (chunk_node, encoded_node) = get_comparands(
906        core,
907        &source_node,
908        encoded,
909        frame_range,
910        probe_res,
911        sample_rate,
912    )?;
913    let (compared_node, butteraugli_key) =
914        compare_butteraugli(core, &chunk_node, &encoded_node, submetric, plugins)?;
915
916    let mut scores = Vec::new();
917    for frame_index in 0..compared_node.info().num_frames {
918        let score = compared_node.get_frame(frame_index)?.props().get_float(butteraugli_key)?;
919        scores.push(score);
920    }
921
922    Ok(scores)
923}
924
925#[inline]
926pub fn measure_ssimulacra2(
927    source: &Input,
928    encoded: &Path,
929    frame_range: (u32, u32),
930    probe_res: Option<(u32, u32)>,
931    sample_rate: usize,
932    plugins: VapoursynthPlugins,
933) -> anyhow::Result<Vec<f64>> {
934    let mut environment = Environment::new()?;
935    let args = source.as_vspipe_args_map()?;
936    environment.set_variables(&args)?;
937    // Cannot use eval_file because it causes file system access errors during
938    // Target Quality probing
939    environment.eval_script(&source.as_script_text(None, None, None)?)?;
940    let core = environment.get_core()?;
941
942    let source_node = environment.get_output(0)?.0;
943    let (chunk_node, encoded_node) = get_comparands(
944        core,
945        &source_node,
946        encoded,
947        frame_range,
948        probe_res,
949        sample_rate,
950    )?;
951    let (compared_node, ssimulacra_key) =
952        compare_ssimulacra2(core, &chunk_node, &encoded_node, plugins)?;
953
954    let mut scores = Vec::new();
955    for frame_index in 0..compared_node.info().num_frames {
956        let score = compared_node.get_frame(frame_index)?.props().get_float(ssimulacra_key)?;
957        scores.push(score);
958    }
959
960    Ok(scores)
961}
962
963#[inline]
964pub fn measure_xpsnr(
965    submetric: XPSNRSubMetric,
966    source: &Input,
967    encoded: &Path,
968    frame_range: (u32, u32),
969    probe_res: Option<(u32, u32)>,
970    sample_rate: usize,
971    plugins: VapoursynthPlugins,
972) -> anyhow::Result<Vec<f64>> {
973    let mut environment = Environment::new()?;
974    let args = source.as_vspipe_args_map()?;
975    environment.set_variables(&args)?;
976    // Cannot use eval_file because it causes file system access errors during
977    // Target Quality probing
978    environment.eval_script(&source.as_script_text(None, None, None)?)?;
979    let core = environment.get_core()?;
980
981    let source_node = environment.get_output(0)?.0;
982    let (chunk_node, encoded_node) = get_comparands(
983        core,
984        &source_node,
985        encoded,
986        frame_range,
987        probe_res,
988        sample_rate,
989    )?;
990    let compared_node = compare_xpsnr(core, &chunk_node, &encoded_node, plugins)?;
991
992    let mut scores = Vec::new();
993    for frame_index in 0..compared_node.info().num_frames {
994        let frame = compared_node.get_frame(frame_index)?;
995        let xpsnr_y = frame
996            .props()
997            .get_float("XPSNR_Y")
998            .or(Ok::<f64, std::convert::Infallible>(f64::INFINITY))?;
999        let xpsnr_u = frame
1000            .props()
1001            .get_float("XPSNR_U")
1002            .or(Ok::<f64, std::convert::Infallible>(f64::INFINITY))?;
1003        let xpsnr_v = frame
1004            .props()
1005            .get_float("XPSNR_V")
1006            .or(Ok::<f64, std::convert::Infallible>(f64::INFINITY))?;
1007
1008        match submetric {
1009            XPSNRSubMetric::Minimum => {
1010                let minimum = f64::min(xpsnr_y, f64::min(xpsnr_u, xpsnr_v));
1011                scores.push(minimum);
1012            },
1013            XPSNRSubMetric::Weighted => {
1014                let weighted = weight_xpsnr(xpsnr_y, xpsnr_u, xpsnr_v);
1015                scores.push(weighted);
1016            },
1017        }
1018    }
1019
1020    Ok(scores)
1021}