protoc-gen-ts-temporal 0.0.1

protoc plugin that emits a typed TypeScript Temporal client from temporal.v1.* annotated protos
Documentation
//! DescriptorPool → ServiceModel.
//!
//! The descriptor pool MUST have been built via
//! `DescriptorPool::decode_file_descriptor_set` so that extension bytes on
//! `MethodOptions` are preserved. See `bin/protoc-gen-ts-temporal`.

use std::collections::HashSet;

use anyhow::{Result, anyhow};
use prost::Message;
use prost_reflect::{
    DescriptorPool, DynamicMessage, ExtensionDescriptor, FileDescriptor, MethodDescriptor,
    ServiceDescriptor, Value,
};

use crate::model::*;
use crate::proto::temporal::v1 as cludden;

// Fully-qualified names of the `temporal.v1` option extensions. Each one is
// looked up in the descriptor pool by `ExtensionCtx::resolve`, which fails
// if any of them is not defined there.

// Method-level extensions: their presence on a method's options decides the
// method's kind.
const WORKFLOW_EXT: &str = "temporal.v1.workflow";
const SIGNAL_EXT: &str = "temporal.v1.signal";
const QUERY_EXT: &str = "temporal.v1.query";
const UPDATE_EXT: &str = "temporal.v1.update";
const ACTIVITY_EXT: &str = "temporal.v1.activity";
// Service-level extension: read from the service's options.
const SERVICE_EXT: &str = "temporal.v1.service";

/// Build a [`ServiceModel`] for every annotated service in the requested files.
///
/// Only files of `pool` whose name appears in `files_to_generate` are
/// inspected. Services without any annotated method yield no model.
///
/// # Errors
///
/// Fails if one of the `temporal.v1` extensions is not defined in `pool`, or
/// if parsing any service fails.
pub fn parse(
    pool: &DescriptorPool,
    files_to_generate: &HashSet<String>,
) -> Result<Vec<ServiceModel>> {
    let ctx = ExtensionCtx::resolve(pool)?;

    let mut models = Vec::new();
    let requested = pool
        .files()
        .filter(|candidate| files_to_generate.contains(candidate.name()));
    for file in requested {
        for service in file.services() {
            // `extend` over the `Option` pushes the model only when there is one.
            models.extend(parse_service(&file, &service, &ctx)?);
        }
    }
    Ok(models)
}

/// Descriptors of the `temporal.v1` extensions, resolved once from the pool
/// and then passed by reference to the per-service and per-method parsers.
struct ExtensionCtx {
    // Method-option extensions; `method_kind` probes them in this order.
    workflow: ExtensionDescriptor,
    signal: ExtensionDescriptor,
    query: ExtensionDescriptor,
    update: ExtensionDescriptor,
    activity: ExtensionDescriptor,
    // Service-option extension, read by `parse_service_options`.
    service: ExtensionDescriptor,
}

impl ExtensionCtx {
    fn resolve(pool: &DescriptorPool) -> Result<Self> {
        let need = |name: &str| {
            pool.get_extension_by_name(name)
                .ok_or_else(|| anyhow!("missing extension definition: {name}"))
        };
        Ok(Self {
            workflow: need(WORKFLOW_EXT)?,
            signal: need(SIGNAL_EXT)?,
            query: need(QUERY_EXT)?,
            update: need(UPDATE_EXT)?,
            activity: need(ACTIVITY_EXT)?,
            service: need(SERVICE_EXT)?,
        })
    }
}

/// Turn one service descriptor into a [`ServiceModel`].
///
/// Each method is classified by `method_kind` and collected into the matching
/// list. Returns `Ok(None)` when no method of the service carries any of the
/// recognised annotations.
///
/// # Errors
///
/// Propagates failures from decoding the service options or any method
/// options.
fn parse_service(
    file: &FileDescriptor,
    service: &ServiceDescriptor,
    ctx: &ExtensionCtx,
) -> Result<Option<ServiceModel>> {
    // Service options are decoded before any method so that their error, if
    // any, surfaces first.
    let default_task_queue = match parse_service_options(service, ctx)? {
        Some(service_opts) => service_opts.task_queue,
        None => Default::default(),
    };

    let mut workflows = Vec::new();
    let mut signals = Vec::new();
    let mut queries = Vec::new();
    let mut updates = Vec::new();
    let mut activities = Vec::new();

    for rpc in service.methods() {
        match method_kind(&rpc, ctx)? {
            MethodKind::Workflow(boxed) => workflows.push(parse_workflow(&rpc, *boxed)),
            MethodKind::Signal(options) => signals.push(parse_signal(&rpc, options)),
            MethodKind::Query(options) => queries.push(parse_query(&rpc, options)),
            MethodKind::Update(boxed) => updates.push(parse_update(&rpc, *boxed)),
            MethodKind::Activity(boxed) => activities.push(parse_activity(&rpc, *boxed)),
            MethodKind::None => {}
        }
    }

    let annotated =
        workflows.len() + signals.len() + queries.len() + updates.len() + activities.len();
    if annotated == 0 {
        return Ok(None);
    }

    let model = ServiceModel {
        package: file.package_name().to_string(),
        service: service.name().to_string(),
        source_file: file.name().to_string(),
        default_task_queue,
        workflows,
        signals,
        queries,
        updates,
        activities,
    };
    Ok(Some(model))
}

/// Decode the `temporal.v1.service` extension from a service's options.
///
/// Returns `Ok(None)` when the service does not carry the extension.
///
/// # Errors
///
/// Propagates any failure from `decode_ext`.
fn parse_service_options(
    service: &ServiceDescriptor,
    ctx: &ExtensionCtx,
) -> Result<Option<cludden::ServiceOptions>> {
    let options = service.options();
    if options.has_extension(&ctx.service) {
        decode_ext::<cludden::ServiceOptions>(&options, &ctx.service).map(Some)
    } else {
        Ok(None)
    }
}

/// Build a [`WorkflowModel`] from a workflow-annotated RPC method.
///
/// `opts` is consumed: its fields are moved into the model rather than
/// cloned. `registered_name` is `None` when the options leave `name` empty.
/// The signal, query and update entries attached to the workflow are mapped
/// one-to-one into their `*Ref` counterparts.
fn parse_workflow(method: &MethodDescriptor, opts: cludden::WorkflowOptions) -> WorkflowModel {
    WorkflowModel {
        rpc_method: method.name().to_string(),
        input_type: ProtoType::from_full_name(method.input().full_name()),
        output_type: ProtoType::from_full_name(method.output().full_name()),
        task_queue: opts.task_queue,
        // An empty name means "not set".
        registered_name: Some(opts.name).filter(|name| !name.is_empty()),
        aliases: opts.aliases,
        attached_signals: opts
            .signal
            .into_iter()
            .map(|s| SignalRef {
                method: s.r#ref,
                signal_with_start: s.start,
            })
            .collect(),
        attached_queries: opts
            .query
            .into_iter()
            .map(|q| QueryRef { method: q.r#ref })
            .collect(),
        attached_updates: opts
            .update
            .into_iter()
            .map(|u| UpdateRef {
                method: u.r#ref,
                update_with_start: u.start,
                update_with_validation: u.validate,
            })
            .collect(),
    }
}

/// Build a `SignalModel` from a signal-annotated RPC method.
///
/// The registered name is the `name` given in the options, falling back to
/// the RPC method name when that is empty.
fn parse_signal(method: &MethodDescriptor, opts: cludden::SignalOptions) -> SignalModel {
    let rpc_method = method.name().to_string();
    let registered_name = Some(opts.name)
        .filter(|explicit| !explicit.is_empty())
        .unwrap_or_else(|| rpc_method.clone());
    let input_type = ProtoType::from_full_name(method.input().full_name());
    let output_type = ProtoType::from_full_name(method.output().full_name());

    SignalModel {
        rpc_method,
        input_type,
        output_type,
        registered_name,
    }
}

/// Build a `QueryModel` from a query-annotated RPC method.
///
/// A non-empty `name` in the options becomes the registered name; otherwise
/// the RPC method name is used.
fn parse_query(method: &MethodDescriptor, opts: cludden::QueryOptions) -> QueryModel {
    let registered_name = match opts.name {
        explicit if !explicit.is_empty() => explicit,
        _ => method.name().to_string(),
    };
    let input = method.input();
    let output = method.output();

    QueryModel {
        rpc_method: method.name().to_string(),
        input_type: ProtoType::from_full_name(input.full_name()),
        output_type: ProtoType::from_full_name(output.full_name()),
        registered_name,
    }
}

/// Build an `UpdateModel` from an update-annotated RPC method.
///
/// Uses the `name` from the options as the registered name unless it is
/// empty, in which case the RPC method name is used instead.
fn parse_update(method: &MethodDescriptor, opts: cludden::UpdateOptions) -> UpdateModel {
    let rpc_method = method.name().to_string();
    let has_explicit_name = !opts.name.is_empty();
    let registered_name = if has_explicit_name {
        opts.name
    } else {
        rpc_method.clone()
    };

    UpdateModel {
        rpc_method,
        input_type: ProtoType::from_full_name(method.input().full_name()),
        output_type: ProtoType::from_full_name(method.output().full_name()),
        registered_name,
    }
}

/// Build an `ActivityModel` from an activity-annotated RPC method.
///
/// Only names are recorded: the RPC method name, and the registered name
/// (the options' `name`, or the RPC method name when that is empty).
fn parse_activity(method: &MethodDescriptor, opts: cludden::ActivityOptions) -> ActivityModel {
    let rpc_method = method.name().to_string();
    let registered_name = match opts.name {
        unnamed if unnamed.is_empty() => rpc_method.clone(),
        explicit => explicit,
    };
    ActivityModel {
        rpc_method,
        registered_name,
    }
}

/// Classification of an RPC method by the `temporal.v1` extension found on
/// its options, carrying the decoded options for that extension.
enum MethodKind {
    // Workflow, Update and Activity options are boxed; Signal and Query are
    // stored inline. NOTE(review): presumably the boxed option types are the
    // large ones and boxing keeps the enum small — confirm against the
    // generated message sizes.
    Workflow(Box<cludden::WorkflowOptions>),
    Signal(cludden::SignalOptions),
    Query(cludden::QueryOptions),
    Update(Box<cludden::UpdateOptions>),
    Activity(Box<cludden::ActivityOptions>),
    // The method carries none of the recognised extensions.
    None,
}

/// Classify an RPC method by the `temporal.v1` extension on its options.
///
/// Extensions are probed in a fixed order — workflow, signal, query, update,
/// activity — and the first one present decides the kind; later ones are not
/// examined. A method with none of them is `MethodKind::None`.
///
/// # Errors
///
/// Propagates any failure from decoding the matched extension.
fn method_kind(method: &MethodDescriptor, ctx: &ExtensionCtx) -> Result<MethodKind> {
    let options = method.options();

    let kind = if options.has_extension(&ctx.workflow) {
        MethodKind::Workflow(Box::new(decode_ext(&options, &ctx.workflow)?))
    } else if options.has_extension(&ctx.signal) {
        MethodKind::Signal(decode_ext(&options, &ctx.signal)?)
    } else if options.has_extension(&ctx.query) {
        MethodKind::Query(decode_ext(&options, &ctx.query)?)
    } else if options.has_extension(&ctx.update) {
        MethodKind::Update(Box::new(decode_ext(&options, &ctx.update)?))
    } else if options.has_extension(&ctx.activity) {
        MethodKind::Activity(Box::new(decode_ext(&options, &ctx.activity)?))
    } else {
        MethodKind::None
    };

    Ok(kind)
}

/// Re-encode a `DynamicMessage` extension value back to wire bytes and decode
/// it into the prost-generated typed message. This is the bridge from the
/// reflection layer (which preserves extensions) to the typed layer (which is
/// nicer to read).
fn decode_ext<M: Message + Default>(opts: &DynamicMessage, ext: &ExtensionDescriptor) -> Result<M> {
    let value = opts.get_extension(ext);
    match value.as_ref() {
        Value::Message(m) => Ok(M::decode(&*m.encode_to_vec())?),
        other => Err(anyhow!(
            "extension {} expected to be a message, got {:?}",
            ext.full_name(),
            other
        )),
    }
}